diff --git a/fixtures/config.toml b/fixtures/config.toml index 7109c50..172d7e8 100644 --- a/fixtures/config.toml +++ b/fixtures/config.toml @@ -28,8 +28,7 @@ invalid-usernames = ["用户名"] invalid-username-message = "Invalid username %s. Please check https://vlab.ustc.edu.cn/docs/login/ssh/#username for more information." [logger] -enabled = true -endpoint = "udp://127.0.0.1:5556" +enabled = false [proxy-protocol] enabled = true diff --git a/sshmux.go b/sshmux.go index 39a9610..369dc0d 100644 --- a/sshmux.go +++ b/sshmux.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "fmt" + "io" "log" "log/slog" "net" @@ -28,7 +29,7 @@ type Server struct { Banner string SSHConfig *ssh.ServerConfig Authenticator Authenticator - LogWriter *net.Conn + LogWriter io.Writer ProxyPolicy ProxyPolicyConfig UsernamePolicy UsernamePolicyConfig PasswordPolicy PasswordPolicyConfig @@ -78,7 +79,7 @@ func makeServer(config Config) (*Server, error) { if err != nil { return nil, err } - var loggerEndpoint *net.Conn = nil + var logWriter io.Writer if config.Logger.Enabled { loggerURL, err := url.Parse(config.Logger.Endpoint) if err != nil { @@ -89,17 +90,19 @@ func makeServer(config Config) (*Server, error) { if err != nil { log.Fatalf("Logger Dial failed: %s\n", err) } - loggerEndpoint = &conn + logWriter = conn } else { log.Fatalf("unsupported logger endpoint: %s\n", config.Logger.Endpoint) } + } else { + logWriter = io.Discard } sshmux := &Server{ Address: config.Address, Banner: config.SSH.Banner, SSHConfig: sshConfig, Authenticator: makeAuthenticator(config.Auth, config.Recovery), - LogWriter: loggerEndpoint, + LogWriter: logWriter, ProxyPolicy: proxyPolicyConfig, UsernamePolicy: UsernamePolicyConfig{ InvalidUsernames: config.Auth.InvalidUsernames, @@ -145,10 +148,7 @@ func (s *Server) handler(conn net.Conn) { } defer session.Close() - var logger *slog.Logger = nil - if s.LogWriter != nil { - logger = slog.New(slog.NewJSONHandler(*s.LogWriter, nil)) - } + logger := slog.New(slog.NewJSONHandler(s.LogWriter, nil)) logger = logger.With( slog.Int64("connect_time", time.Now().Unix()), slog.String("remote_ip", conn.RemoteAddr().String()),