diff --git a/conn.go b/conn.go index 04b12371..f3b9bb21 100644 --- a/conn.go +++ b/conn.go @@ -1,6 +1,8 @@ package soju import ( + "bufio" + "bytes" "context" "fmt" "io" @@ -10,6 +12,7 @@ import ( "time" "unicode" + "github.com/DataDog/zstd" "golang.org/x/time/rate" "gopkg.in/irc.v3" "nhooyr.io/websocket" @@ -25,14 +28,16 @@ type ircConn interface { SetWriteDeadline(time.Time) error RemoteAddr() net.Addr LocalAddr() net.Addr + SupportsCompression() bool + EnableReadCompression() error + EnableWriteCompression() error } func newNetIRCConn(c net.Conn) ircConn { - type netConn net.Conn - return struct { - *irc.Conn - netConn - }{irc.NewConn(c), c} + return &tcpIRCConn{ + Conn: c, + r: bufio.NewReader(c), + } } type websocketIRCConn struct { @@ -109,6 +114,18 @@ func (wic *websocketIRCConn) LocalAddr() net.Addr { return websocketAddr("") } +func (wic websocketIRCConn) SupportsCompression() bool { + return false +} + +func (wic websocketIRCConn) EnableReadCompression() error { + return fmt.Errorf("websocket: compression is unsupported") +} + +func (wic websocketIRCConn) EnableWriteCompression() error { + return fmt.Errorf("websocket: compression is unsupported") +} + type websocketAddr string func (websocketAddr) Network() string { @@ -119,6 +136,74 @@ func (wa websocketAddr) String() string { return string(wa) } +type tcpIRCConn struct { + net.Conn + wz *zstd.Writer + rz io.ReadCloser + r *bufio.Reader +} + +func (tic *tcpIRCConn) ReadMessage() (msg *irc.Message, err error) { + err = irc.ErrZeroLengthMessage + for err == irc.ErrZeroLengthMessage { + var line string + line, err = tic.r.ReadString('\n') + if err != nil { + return nil, err + } + msg, err = irc.ParseMessage(line) + } + return msg, err +} + +func (tic *tcpIRCConn) WriteMessage(msg *irc.Message) error { + data := []byte(msg.String() + "\r\n") + if tic.wz != nil { + _, err := tic.wz.Write(data) + if err != nil { + return err + } + return tic.wz.Flush() + } + _, err := tic.Conn.Write(data) + return err +} + +func (tic *tcpIRCConn) Close() error { + if tic.wz != nil { + tic.wz.Close() + } + if tic.rz != nil { + tic.rz.Close() + } + return tic.Conn.Close() +} + +func (tic *tcpIRCConn) SupportsCompression() bool { + return true +} + +func (tic *tcpIRCConn) EnableReadCompression() error { + if tic.rz == nil { + tic.rz = zstd.NewReader(tic) + rem, err := tic.r.Peek(tic.r.Buffered()) + if err != nil { + return err + } + remRd := bytes.NewReader(rem) + mr := io.MultiReader(remRd, tic.rz) + tic.r = bufio.NewReader(mr) + } + return nil +} + +func (tic *tcpIRCConn) EnableWriteCompression() error { + if tic.wz == nil { + tic.wz = zstd.NewWriterLevel(tic, 1) + } + return nil +} + type connOptions struct { Logger Logger RateLimitDelay time.Duration @@ -162,6 +247,12 @@ func newConn(srv *Server, ic ircConn, options *connOptions) *conn { c.logger.Printf("failed to write message: %v", err) break } + if msg.Command == "COMPRESS" { + if err := c.conn.EnableWriteCompression(); err != nil { + c.logger.Printf("failed to enable compression: %v", err) + break + } + } } if err := c.conn.Close(); err != nil && !isErrClosed(err) { c.logger.Printf("failed to close connection: %v", err) @@ -207,6 +298,11 @@ func (c *conn) ReadMessage() (*irc.Message, error) { } else if err != nil { return nil, err } + if msg.Command == "COMPRESS" && c.conn.SupportsCompression() { + if err := c.conn.EnableReadCompression(); err != nil { + return nil, err + } + } c.logger.Debugf("received: %v", msg) return msg, nil diff --git a/downstream.go b/downstream.go index 30ef902f..b454c3fa 100644 --- a/downstream.go +++ b/downstream.go @@ -337,6 +337,9 @@ func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn { if srv.Config().LogPath != "" { dc.supportedCaps["draft/chathistory"] = "" } + if dc.conn.conn.SupportsCompression() { + dc.supportedCaps["draft/compression"] = "" + } return dc } @@ -796,11 +799,23 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir dc.networkName = match.GetName() } + case "COMPRESS": + // handled in dc.conn.ReadMessage() if supported + if !dc.conn.conn.SupportsCompression() { + return newUnknownCommandError(msg.Command) + } default: dc.logger.Printf("unhandled message: %v", msg) return newUnknownCommandError(msg.Command) } if dc.rawUsername != "" && dc.nick != "*" && !dc.negotiatingCaps { + if dc.caps["draft/compression"] { + // triggers compression at dc.conn level + dc.SendMessage(&irc.Message{ + Prefix: dc.srv.prefix(), + Command: "COMPRESS", + }) + } return dc.register(ctx) } return nil @@ -2843,6 +2858,11 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc. }) } }) + case "COMPRESS": + // handled in dc.conn.ReadMessage() if supported + if !dc.conn.conn.SupportsCompression() { + return newUnknownCommandError(msg.Command) + } case "BOUNCER": var subcommand string if err := parseMessageParams(msg, &subcommand); err != nil { diff --git a/go.mod b/go.mod index 462a98bf..ab9c4a02 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.15 require ( git.sr.ht/~emersion/go-scfg v0.0.0-20201019143924-142a8aa629fc git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 + github.com/DataDog/zstd v1.5.0 github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/emersion/go-sasl v0.0.0-20211008083017-0b9dcfb154ac github.com/golang/protobuf v1.5.2 // indirect diff --git a/go.sum b/go.sum index 8c11e230..f5bbf08a 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMA git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/DataDog/zstd v1.5.0 h1:+K/VEwIAaPcHiMtQvpLD4lqW7f0Gk3xdYZmI1hD+CXo= +github.com/DataDog/zstd v1.5.0/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=