From 3af803b7ad8e7a8dc7f3e5d18e3bf51da4f66bbc Mon Sep 17 00:00:00 2001 From: sbruens Date: Mon, 23 Sep 2024 12:41:59 -0400 Subject: [PATCH] Use concrete `slog.Logger` instead of `Logger` interface now that we don't need a zap adapter for Caddy. --- service/logger.go | 19 +++---------------- service/shadowsocks.go | 7 ++++--- service/tcp.go | 32 ++++++++++++++++---------------- service/tcp_test.go | 4 ++-- service/udp.go | 36 ++++++++++++++++++------------------ service/udp_test.go | 10 +++++----- 6 files changed, 48 insertions(+), 60 deletions(-) diff --git a/service/logger.go b/service/logger.go index b751fb8e..50fb4fe1 100644 --- a/service/logger.go +++ b/service/logger.go @@ -15,23 +15,10 @@ package service import ( - "context" + "io" "log/slog" ) -type Logger interface { - Enabled(ctx context.Context, level slog.Level) bool - LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) -} - -type noopLogger struct { -} - -var _ Logger = (*noopLogger)(nil) - -func (l *noopLogger) Enabled(ctx context.Context, level slog.Level) bool { - return false -} - -func (l *noopLogger) LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) { +func noopLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError})) } diff --git a/service/shadowsocks.go b/service/shadowsocks.go index 073f434e..636fa94e 100644 --- a/service/shadowsocks.go +++ b/service/shadowsocks.go @@ -16,6 +16,7 @@ package service import ( "context" + "log/slog" "net" "time" @@ -50,7 +51,7 @@ type Service interface { type Option func(s *ssService) type ssService struct { - logger Logger + logger *slog.Logger metrics ServiceMetrics ciphers CipherList natTimeout time.Duration @@ -74,7 +75,7 @@ func NewShadowsocksService(opts ...Option) (Service, error) { } // If no logger is provided via options, use a noop logger. if s.logger == nil { - s.logger = &noopLogger{} + s.logger = noopLogger() } // TODO: Register initial data metrics at zero. @@ -92,7 +93,7 @@ func NewShadowsocksService(opts ...Option) (Service, error) { // WithLogger can be used to provide a custom log target. If not provided, // the service uses a noop logger (i.e., no logging). -func WithLogger(l Logger) Option { +func WithLogger(l *slog.Logger) Option { return func(s *ssService) { s.logger = l } diff --git a/service/tcp.go b/service/tcp.go index 44473df8..e61f0aa3 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -58,7 +58,7 @@ func remoteIP(conn net.Conn) netip.Addr { } // Wrapper for slog.Debug during TCP access key searches. -func debugTCP(l Logger, template string, cipherID string, attr slog.Attr) { +func debugTCP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction // between Go's inlining/escape analysis and varargs functions like slog.Debug. if l != nil && l.Enabled(nil, slog.LevelDebug) { @@ -72,7 +72,7 @@ func debugTCP(l Logger, template string, cipherID string, attr slog.Attr) { // required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection. const bytesForKeyFinding = 50 -func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList, l Logger) (*CipherEntry, io.Reader, []byte, time.Duration, error) { +func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList, l *slog.Logger) (*CipherEntry, io.Reader, []byte, time.Duration, error) { // We snapshot the list because it may be modified while we use it. ciphers := cipherList.SnapshotForClientIP(clientIP) firstBytes := make([]byte, bytesForKeyFinding) @@ -95,7 +95,7 @@ func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList Ciphe } // Implements a trial decryption search. This assumes that all ciphers are AEAD. -func findEntry(firstBytes []byte, ciphers []*list.Element, l Logger) (*CipherEntry, *list.Element) { +func findEntry(firstBytes []byte, ciphers []*list.Element, l *slog.Logger) (*CipherEntry, *list.Element) { // To hold the decrypted chunk length. chunkLenBuf := [2]byte{} for ci, elt := range ciphers { @@ -116,12 +116,12 @@ type StreamAuthenticateFunc func(clientConn transport.StreamConn) (string, trans // NewShadowsocksStreamAuthenticator creates a stream authenticator that uses Shadowsocks. // TODO(fortuna): Offer alternative transports. -func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksConnMetrics, l Logger) StreamAuthenticateFunc { +func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksConnMetrics, l *slog.Logger) StreamAuthenticateFunc { if metrics == nil { metrics = &NoOpShadowsocksConnMetrics{} } if l == nil { - l = &noopLogger{} + l = noopLogger() } return func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError) { // Find the cipher and acess key id. @@ -157,7 +157,7 @@ func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCa } type streamHandler struct { - l Logger + logger *slog.Logger listenerId string readTimeout time.Duration authenticate StreamAuthenticateFunc @@ -167,7 +167,7 @@ type streamHandler struct { // NewStreamHandler creates a StreamHandler func NewStreamHandler(authenticate StreamAuthenticateFunc, timeout time.Duration) StreamHandler { return &streamHandler{ - l: &noopLogger{}, + logger: noopLogger(), readTimeout: timeout, authenticate: authenticate, dialer: defaultDialer, @@ -187,16 +187,16 @@ func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) tra type StreamHandler interface { Handle(ctx context.Context, conn transport.StreamConn, connMetrics TCPConnMetrics) // SetLogger sets the logger used to log messages. Uses a no-op logger if nil. - SetLogger(l Logger) + SetLogger(l *slog.Logger) // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses. SetTargetDialer(dialer transport.StreamDialer) } -func (s *streamHandler) SetLogger(l Logger) { +func (s *streamHandler) SetLogger(l *slog.Logger) { if l == nil { - l = &noopLogger{} + l = noopLogger() } - s.l = l + s.logger = l } func (s *streamHandler) SetTargetDialer(dialer transport.StreamDialer) { @@ -271,11 +271,11 @@ func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamC status := "OK" if connError != nil { status = connError.Status - h.l.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) + h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) } connMetrics.AddClosed(status, proxyMetrics, connDuration) measuredClientConn.Close() // Closing after the metrics are added aids integration testing. - h.l.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration)) + h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration)) } func getProxyRequest(clientConn transport.StreamConn) (string, error) { @@ -290,7 +290,7 @@ func getProxyRequest(clientConn transport.StreamConn) (string, error) { return tgtAddr.String(), nil } -func proxyConnection(l Logger, ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError { +func proxyConnection(l *slog.Logger, ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError { tgtConn, dialErr := dialer.DialStream(ctx, tgtAddr) if dialErr != nil { // We don't drain so dial errors and invalid addresses are communicated quickly. @@ -365,7 +365,7 @@ func (h *streamHandler) handleConnection(ctx context.Context, outerConn transpor tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy) return tgtConn, nil }) - return proxyConnection(h.l, ctx, dialer, tgtAddr, innerConn) + return proxyConnection(h.logger, ctx, dialer, tgtAddr, innerConn) } // Keep the connection open until we hit the authentication deadline to protect against probing attacks @@ -374,7 +374,7 @@ func (h *streamHandler) absorbProbe(clientConn io.ReadCloser, connMetrics TCPCon // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe. _, drainErr := io.Copy(io.Discard, clientConn) // drain socket drainResult := drainErrToString(drainErr) - h.l.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult)) + h.logger.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult)) connMetrics.AddProbe(status, drainResult, proxyMetrics.ClientProxy) } diff --git a/service/tcp_test.go b/service/tcp_test.go index e8f234f1..e69d1d1e 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -102,7 +102,7 @@ func BenchmarkTCPFindCipherFail(b *testing.B) { } clientIP := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr() b.StartTimer() - findAccessKey(clientConn, clientIP, cipherList, &noopLogger{}) + findAccessKey(clientConn, clientIP, cipherList, noopLogger()) b.StopTimer() } } @@ -205,7 +205,7 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) { cipher := cipherEntries[cipherNumber].CryptoKey go shadowsocks.NewWriter(writer, cipher).Write(makeTestPayload(50)) b.StartTimer() - _, _, _, _, err := findAccessKey(&c, clientIP, cipherList, &noopLogger{}) + _, _, _, _, err := findAccessKey(&c, clientIP, cipherList, noopLogger()) b.StopTimer() if err != nil { b.Error(err) diff --git a/service/udp.go b/service/udp.go index df034b15..8ff5352f 100644 --- a/service/udp.go +++ b/service/udp.go @@ -44,7 +44,7 @@ type UDPMetrics interface { const serverUDPBufferSize = 64 * 1024 // Wrapper for slog.Debug during UDP proxying. -func debugUDP(l Logger, template string, cipherID string, attr slog.Attr) { +func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction // between Go's inlining/escape analysis and varargs functions like slog.Debug. if l.Enabled(nil, slog.LevelDebug) { @@ -52,7 +52,7 @@ func debugUDP(l Logger, template string, cipherID string, attr slog.Attr) { } } -func debugUDPAddr(l Logger, template string, addr net.Addr, attr slog.Attr) { +func debugUDPAddr(l *slog.Logger, template string, addr net.Addr, attr slog.Attr) { if l.Enabled(nil, slog.LevelDebug) { l.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("UDP: %s", template), slog.String("address", addr.String()), attr) } @@ -60,7 +60,7 @@ func debugUDPAddr(l Logger, template string, addr net.Addr, attr slog.Attr) { // Decrypts src into dst. It tries each cipher until it finds one that authenticates // correctly. dst and src must not overlap. -func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList, l Logger) ([]byte, string, *shadowsocks.EncryptionKey, error) { +func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList, l *slog.Logger) ([]byte, string, *shadowsocks.EncryptionKey, error) { // Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD. // We snapshot the list because it may be modified while we use it. snapshot := cipherList.SnapshotForClientIP(clientIP) @@ -80,7 +80,7 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis } type packetHandler struct { - l Logger + logger *slog.Logger natTimeout time.Duration ciphers CipherList m UDPMetrics @@ -97,7 +97,7 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr ssMetrics = &NoOpShadowsocksConnMetrics{} } return &packetHandler{ - l: &noopLogger{}, + logger: noopLogger(), natTimeout: natTimeout, ciphers: cipherList, m: m, @@ -109,18 +109,18 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr // PacketHandler is a running UDP shadowsocks proxy that can be stopped. type PacketHandler interface { // SetLogger sets the logger used to log messages. Uses a no-op logger if nil. - SetLogger(l Logger) + SetLogger(l *slog.Logger) // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) // Handle returns after clientConn closes and all the sub goroutines return. Handle(clientConn net.PacketConn) } -func (h *packetHandler) SetLogger(l Logger) { +func (h *packetHandler) SetLogger(l *slog.Logger) { if l == nil { - l = &noopLogger{} + l = noopLogger() } - h.l = l + h.logger = l } func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { @@ -132,7 +132,7 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali func (h *packetHandler) Handle(clientConn net.PacketConn) { var running sync.WaitGroup - nm := newNATmap(h.natTimeout, h.m, &running, h.l) + nm := newNATmap(h.natTimeout, h.m, &running, h.logger) defer nm.Close() cipherBuf := make([]byte, serverUDPBufferSize) textBuf := make([]byte, serverUDPBufferSize) @@ -160,7 +160,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { return onet.NewConnectionError("ERR_READ", "Failed to read from client", err) } defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAddr.String())) - debugUDPAddr(h.l, "Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes)) + debugUDPAddr(h.logger, "Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes)) cipherData := cipherBuf[:clientProxyBytes] var payload []byte @@ -171,7 +171,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { var textData []byte var cryptoKey *shadowsocks.EncryptionKey unpackStart := time.Now() - textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.l) + textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.logger) timeToCipher := time.Since(unpackStart) h.ssm.AddCipherSearch(err == nil, timeToCipher) @@ -208,7 +208,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { } } - debugUDPAddr(h.l, "Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr())) + debugUDPAddr(h.logger, "Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr())) proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) @@ -317,14 +317,14 @@ func (c *natconn) ReadFrom(buf []byte) (int, net.Addr, error) { type natmap struct { sync.RWMutex keyConn map[string]*natconn - l Logger + logger *slog.Logger timeout time.Duration metrics UDPMetrics running *sync.WaitGroup } -func newNATmap(timeout time.Duration, sm UDPMetrics, running *sync.WaitGroup, l Logger) *natmap { - m := &natmap{l: l, metrics: sm, running: running} +func newNATmap(timeout time.Duration, sm UDPMetrics, running *sync.WaitGroup, l *slog.Logger) *natmap { + m := &natmap{logger: l, metrics: sm, running: running} m.keyConn = make(map[string]*natconn) m.timeout = timeout return m @@ -370,7 +370,7 @@ func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey * m.running.Add(1) go func() { - timedCopy(clientAddr, clientConn, entry, keyID, m.l) + timedCopy(clientAddr, clientConn, entry, keyID, m.logger) connMetrics.RemoveNatEntry() if pc := m.del(clientAddr.String()); pc != nil { pc.Close() @@ -399,7 +399,7 @@ func (m *natmap) Close() error { var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // copy from target to client until read timeout -func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, keyID string, l Logger) { +func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, keyID string, l *slog.Logger) { // pkt is used for in-place encryption of downstream UDP packets, with the layout // [padding?][salt][address][body][tag][extra] // Padding is only used if the address is IPv4. diff --git a/service/udp_test.go b/service/udp_test.go index 0d62a02a..ae792363 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -207,14 +207,14 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) { } func TestNATEmpty(t *testing.T) { - nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, &noopLogger{}) + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, noopLogger()) if nat.Get("foo") != nil { t.Error("Expected nil value from empty NAT map") } } func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) { - nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, &noopLogger{}) + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, noopLogger()) clientConn := makePacketConn() targetConn := makePacketConn() nat.Add(&clientAddr, clientConn, natCryptoKey, targetConn, "key id") @@ -409,7 +409,7 @@ func BenchmarkUDPUnpackFail(b *testing.B) { testIP := netip.MustParseAddr("192.0.2.1") b.ResetTimer() for n := 0; n < b.N; n++ { - findAccessKeyUDP(testIP, textBuf, testPayload, cipherList, &noopLogger{}) + findAccessKeyUDP(testIP, textBuf, testPayload, cipherList, noopLogger()) } } @@ -439,7 +439,7 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) { cipherNumber := n % numCiphers ip := ips[cipherNumber] packet := packets[cipherNumber] - _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList, &noopLogger{}) + _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList, noopLogger()) if err != nil { b.Error(err) } @@ -468,7 +468,7 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { ip := ips[n%numIPs] - _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList, &noopLogger{}) + _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList, noopLogger()) if err != nil { b.Error(err) }