diff --git a/client/client.go b/client/client.go index fd503ff6e..792ae4a61 100644 --- a/client/client.go +++ b/client/client.go @@ -1,6 +1,7 @@ package client import ( + "context" "net" "github.com/9seconds/mtg/config" @@ -9,4 +10,5 @@ import ( ) // Init defines common method for initializing client connections. -type Init func(net.Conn, string, *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) +type Init func(context.Context, context.CancelFunc, net.Conn, string, + *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) diff --git a/client/direct.go b/client/direct.go index 01730bdec..6545b9ec5 100644 --- a/client/direct.go +++ b/client/direct.go @@ -1,6 +1,7 @@ package client import ( + "context" "net" "time" @@ -16,7 +17,8 @@ const handshakeTimeout = 10 * time.Second // DirectInit initializes client connection for proxy which connects to // Telegram directly. -func DirectInit(socket net.Conn, connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) { +func DirectInit(ctx context.Context, cancel context.CancelFunc, socket net.Conn, + connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) { tcpSocket := socket.(*net.TCPConn) if err := tcpSocket.SetNoDelay(false); err != nil { return nil, nil, errors.Annotate(err, "Cannot disable NO_DELAY to client socket") @@ -35,7 +37,7 @@ func DirectInit(socket net.Conn, connID string, conf *config.Config) (wrappers.W } socket.SetReadDeadline(time.Time{}) // nolint: errcheck - conn := wrappers.NewConn(socket, connID, wrappers.ConnPurposeClient, conf.PublicIPv4, conf.PublicIPv6) + conn := wrappers.NewConn(ctx, cancel, socket, connID, wrappers.ConnPurposeClient, conf.PublicIPv4, conf.PublicIPv6) obfs2, connOpts, err := obfuscated2.ParseObfuscated2ClientFrame(conf.Secret, frame) if err != nil { return nil, nil, errors.Annotate(err, "Cannot parse obfuscated frame") diff --git a/client/middle.go b/client/middle.go index 4e5089557..4c41306c7 100644 --- a/client/middle.go +++ b/client/middle.go @@ -1,6 +1,7 @@ package client import ( + "context" "net" "github.com/9seconds/mtg/config" @@ -10,8 +11,9 @@ import ( // MiddleInit initializes client connection for proxy which has to // support promoted channels, connect to Telegram middle proxies etc. -func MiddleInit(socket net.Conn, connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) { - conn, opts, err := DirectInit(socket, connID, conf) +func MiddleInit(ctx context.Context, cancel context.CancelFunc, socket net.Conn, + connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) { + conn, opts, err := DirectInit(ctx, cancel, socket, connID, conf) if err != nil { return nil, nil, err } diff --git a/proxy/proxy.go b/proxy/proxy.go index 7a26bb71f..182dc0e46 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "io" "net" "sync" @@ -43,6 +44,7 @@ func (p *Proxy) Serve() error { func (p *Proxy) accept(conn net.Conn) { connID := uuid.NewV4().String() log := zap.S().With("connection_id", connID).Named("main") + ctx, cancel := context.WithCancel(context.Background()) defer func() { conn.Close() // nolint: errcheck @@ -55,7 +57,7 @@ func (p *Proxy) accept(conn net.Conn) { log.Infow("Client connected", "addr", conn.RemoteAddr()) - clientConn, opts, err := p.clientInit(conn, connID, p.conf) + clientConn, opts, err := p.clientInit(ctx, cancel, conn, connID, p.conf) if err != nil { log.Errorw("Cannot initialize client connection", "error", err) return @@ -65,7 +67,7 @@ func (p *Proxy) accept(conn net.Conn) { stats.ClientConnected(opts.ConnectionType, clientConn.RemoteAddr()) defer stats.ClientDisconnected(opts.ConnectionType, clientConn.RemoteAddr()) - serverConn, err := p.getTelegramConn(opts, connID) + serverConn, err := p.getTelegramConn(ctx, cancel, opts, connID) if err != nil { log.Errorw("Cannot initialize server connection", "error", err) return @@ -92,8 +94,9 @@ func (p *Proxy) accept(conn net.Conn) { log.Infow("Client disconnected", "addr", conn.RemoteAddr()) } -func (p *Proxy) getTelegramConn(opts *mtproto.ConnectionOpts, connID string) (wrappers.Wrap, error) { - streamConn, err := p.tg.Dial(connID, opts) +func (p *Proxy) getTelegramConn(ctx context.Context, cancel context.CancelFunc, + opts *mtproto.ConnectionOpts, connID string) (wrappers.Wrap, error) { + streamConn, err := p.tg.Dial(ctx, cancel, connID, opts) if err != nil { return nil, errors.Annotate(err, "Cannot dial to Telegram") } diff --git a/telegram/dialer.go b/telegram/dialer.go index 9958403ef..8d192d6f8 100644 --- a/telegram/dialer.go +++ b/telegram/dialer.go @@ -1,6 +1,7 @@ package telegram import ( + "context" "net" "time" @@ -38,12 +39,14 @@ func (t *tgDialer) dial(addr string) (net.Conn, error) { return conn, nil } -func (t *tgDialer) dialRWC(addr, connID string) (wrappers.StreamReadWriteCloser, error) { +func (t *tgDialer) dialRWC(ctx context.Context, cancel context.CancelFunc, + addr, connID string) (wrappers.StreamReadWriteCloser, error) { conn, err := t.dial(addr) if err != nil { return nil, err } - tgConn := wrappers.NewConn(conn, connID, wrappers.ConnPurposeTelegram, t.conf.PublicIPv4, t.conf.PublicIPv6) + tgConn := wrappers.NewConn(ctx, cancel, conn, connID, + wrappers.ConnPurposeTelegram, t.conf.PublicIPv4, t.conf.PublicIPv6) return tgConn, nil } diff --git a/telegram/direct.go b/telegram/direct.go index e5dd6cfe2..0ad3fd078 100644 --- a/telegram/direct.go +++ b/telegram/direct.go @@ -1,6 +1,7 @@ package telegram import ( + "context" "net" "github.com/juju/errors" @@ -32,7 +33,8 @@ type directTelegram struct { baseTelegram } -func (t *directTelegram) Dial(connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) { +func (t *directTelegram) Dial(ctx context.Context, cancel context.CancelFunc, + connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) { dc := connOpts.DC if dc < 0 { dc = -dc @@ -40,7 +42,7 @@ func (t *directTelegram) Dial(connID string, connOpts *mtproto.ConnectionOpts) ( dc = 1 } - return t.baseTelegram.dial(dc-1, connID, connOpts.ConnectionProto) + return t.baseTelegram.dial(ctx, cancel, dc-1, connID, connOpts.ConnectionProto) } func (t *directTelegram) Init(connOpts *mtproto.ConnectionOpts, diff --git a/telegram/middle_caller.go b/telegram/middle_caller.go index f0e6721ff..33b44c65d 100644 --- a/telegram/middle_caller.go +++ b/telegram/middle_caller.go @@ -2,6 +2,7 @@ package telegram import ( "bufio" + "context" "io/ioutil" "net" "net/http" @@ -38,7 +39,7 @@ type middleTelegramCaller struct { httpClient *http.Client } -func (t *middleTelegramCaller) Dial(connID string, +func (t *middleTelegramCaller) Dial(ctx context.Context, cancel context.CancelFunc, connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) { dc := connOpts.DC if dc == 0 { @@ -47,7 +48,7 @@ func (t *middleTelegramCaller) Dial(connID string, t.dialerMutex.RLock() defer t.dialerMutex.RUnlock() - return t.baseTelegram.dial(dc, connID, connOpts.ConnectionProto) + return t.baseTelegram.dial(ctx, cancel, dc, connID, connOpts.ConnectionProto) } func (t *middleTelegramCaller) autoUpdate() { diff --git a/telegram/telegram.go b/telegram/telegram.go index 05617d435..d2b4010c8 100644 --- a/telegram/telegram.go +++ b/telegram/telegram.go @@ -1,6 +1,7 @@ package telegram import ( + "context" "math/rand" "github.com/juju/errors" @@ -11,7 +12,7 @@ import ( // Telegram is an interface for different Telegram work modes. type Telegram interface { - Dial(string, *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) + Dial(context.Context, context.CancelFunc, string, *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) Init(*mtproto.ConnectionOpts, wrappers.StreamReadWriteCloser) (wrappers.Wrap, error) } @@ -22,7 +23,7 @@ type baseTelegram struct { v6Addresses map[int16][]string } -func (b *baseTelegram) dial(dcIdx int16, connID string, +func (b *baseTelegram) dial(ctx context.Context, cancel context.CancelFunc, dcIdx int16, connID string, proto mtproto.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) { addrs := make([]string, 2) @@ -38,7 +39,7 @@ func (b *baseTelegram) dial(dcIdx int16, connID string, } for _, addr := range addrs { - if conn, err := b.dialer.dialRWC(addr, connID); err == nil { + if conn, err := b.dialer.dialRWC(ctx, cancel, addr, connID); err == nil { return conn, err } } diff --git a/wrappers/conn.go b/wrappers/conn.go index da6097c89..09e668f4b 100644 --- a/wrappers/conn.go +++ b/wrappers/conn.go @@ -1,12 +1,14 @@ package wrappers import ( + "context" "net" "time" "go.uber.org/zap" "github.com/9seconds/mtg/stats" + "github.com/juju/errors" ) // ConnPurpose is intended to be identifier of connection purpose. We @@ -39,8 +41,10 @@ const ( // Conn is a basic wrapper for net.Conn providing the most low-level // logic and management as possible. type Conn struct { - connID string conn net.Conn + ctx context.Context + cancel context.CancelFunc + connID string logger *zap.SugaredLogger publicIPv4 net.IP @@ -48,28 +52,46 @@ type Conn struct { } func (c *Conn) Write(p []byte) (int, error) { - c.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite)) // nolint: errcheck - n, err := c.conn.Write(p) + select { + case <-c.ctx.Done(): + return 0, errors.Annotate(c.ctx.Err(), "Cannot write because context was closed") + default: + c.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite)) // nolint: errcheck + n, err := c.conn.Write(p) + if err != nil { + c.cancel() + } - c.logger.Debugw("Write to stream", "bytes", n, "error", err) - stats.EgressTraffic(n) + c.logger.Debugw("Write to stream", "bytes", n, "error", err) + stats.EgressTraffic(n) - return n, err + return n, err + } } func (c *Conn) Read(p []byte) (int, error) { - c.conn.SetReadDeadline(time.Now().Add(connTimeoutRead)) // nolint: errcheck - n, err := c.conn.Read(p) + select { + case <-c.ctx.Done(): + return 0, errors.Annotate(c.ctx.Err(), "Cannot read because context was closed") + default: + c.conn.SetReadDeadline(time.Now().Add(connTimeoutRead)) // nolint: errcheck + n, err := c.conn.Read(p) + if err != nil { + c.cancel() + } - c.logger.Debugw("Read from stream", "bytes", n, "error", err) - stats.IngressTraffic(n) + c.logger.Debugw("Read from stream", "bytes", n, "error", err) + stats.IngressTraffic(n) - return n, err + return n, err + } } // Close closes underlying net.Conn instance. func (c *Conn) Close() error { defer c.logger.Debugw("Close connection") + + c.cancel() return c.conn.Close() } @@ -100,7 +122,8 @@ func (c *Conn) RemoteAddr() *net.TCPAddr { } // NewConn initializes Conn wrapper for net.Conn. -func NewConn(conn net.Conn, connID string, purpose ConnPurpose, publicIPv4, publicIPv6 net.IP) StreamReadWriteCloser { +func NewConn(ctx context.Context, cancel context.CancelFunc, conn net.Conn, + connID string, purpose ConnPurpose, publicIPv4, publicIPv6 net.IP) StreamReadWriteCloser { logger := zap.S().With( "connection_id", connID, "local_address", conn.LocalAddr(), @@ -109,9 +132,11 @@ func NewConn(conn net.Conn, connID string, purpose ConnPurpose, publicIPv4, publ ).Named("conn") wrapper := Conn{ - logger: logger, - connID: connID, conn: conn, + ctx: ctx, + cancel: cancel, + connID: connID, + logger: logger, publicIPv4: publicIPv4, publicIPv6: publicIPv6, }