Skip to content

Commit

Permalink
Merge pull request #33 from 9seconds/contexts
Browse files Browse the repository at this point in the history
Use contexts for Conn wrapper
  • Loading branch information
9seconds authored Jul 28, 2018
2 parents b0d86ab + 9f20e87 commit 243a89a
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 32 deletions.
4 changes: 3 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"net"

"github.com/9seconds/mtg/config"
Expand All @@ -9,4 +10,5 @@ import (
)

// Init defines common method for initializing client connections.
type Init func(net.Conn, string, *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error)
type Init func(context.Context, context.CancelFunc, net.Conn, string,
*config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error)
6 changes: 4 additions & 2 deletions client/direct.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"net"
"time"

Expand All @@ -16,7 +17,8 @@ const handshakeTimeout = 10 * time.Second

// DirectInit initializes client connection for proxy which connects to
// Telegram directly.
func DirectInit(socket net.Conn, connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
func DirectInit(ctx context.Context, cancel context.CancelFunc, socket net.Conn,
connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
tcpSocket := socket.(*net.TCPConn)
if err := tcpSocket.SetNoDelay(false); err != nil {
return nil, nil, errors.Annotate(err, "Cannot disable NO_DELAY to client socket")
Expand All @@ -35,7 +37,7 @@ func DirectInit(socket net.Conn, connID string, conf *config.Config) (wrappers.W
}
socket.SetReadDeadline(time.Time{}) // nolint: errcheck

conn := wrappers.NewConn(socket, connID, wrappers.ConnPurposeClient, conf.PublicIPv4, conf.PublicIPv6)
conn := wrappers.NewConn(ctx, cancel, socket, connID, wrappers.ConnPurposeClient, conf.PublicIPv4, conf.PublicIPv6)
obfs2, connOpts, err := obfuscated2.ParseObfuscated2ClientFrame(conf.Secret, frame)
if err != nil {
return nil, nil, errors.Annotate(err, "Cannot parse obfuscated frame")
Expand Down
6 changes: 4 additions & 2 deletions client/middle.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"net"

"github.com/9seconds/mtg/config"
Expand All @@ -10,8 +11,9 @@ import (

// MiddleInit initializes client connection for proxy which has to
// support promoted channels, connect to Telegram middle proxies etc.
func MiddleInit(socket net.Conn, connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
conn, opts, err := DirectInit(socket, connID, conf)
func MiddleInit(ctx context.Context, cancel context.CancelFunc, socket net.Conn,
connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) {
conn, opts, err := DirectInit(ctx, cancel, socket, connID, conf)
if err != nil {
return nil, nil, err
}
Expand Down
11 changes: 7 additions & 4 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"context"
"io"
"net"
"sync"
Expand Down Expand Up @@ -43,6 +44,7 @@ func (p *Proxy) Serve() error {
func (p *Proxy) accept(conn net.Conn) {
connID := uuid.NewV4().String()
log := zap.S().With("connection_id", connID).Named("main")
ctx, cancel := context.WithCancel(context.Background())

defer func() {
conn.Close() // nolint: errcheck
Expand All @@ -55,7 +57,7 @@ func (p *Proxy) accept(conn net.Conn) {

log.Infow("Client connected", "addr", conn.RemoteAddr())

clientConn, opts, err := p.clientInit(conn, connID, p.conf)
clientConn, opts, err := p.clientInit(ctx, cancel, conn, connID, p.conf)
if err != nil {
log.Errorw("Cannot initialize client connection", "error", err)
return
Expand All @@ -65,7 +67,7 @@ func (p *Proxy) accept(conn net.Conn) {
stats.ClientConnected(opts.ConnectionType, clientConn.RemoteAddr())
defer stats.ClientDisconnected(opts.ConnectionType, clientConn.RemoteAddr())

serverConn, err := p.getTelegramConn(opts, connID)
serverConn, err := p.getTelegramConn(ctx, cancel, opts, connID)
if err != nil {
log.Errorw("Cannot initialize server connection", "error", err)
return
Expand All @@ -92,8 +94,9 @@ func (p *Proxy) accept(conn net.Conn) {
log.Infow("Client disconnected", "addr", conn.RemoteAddr())
}

func (p *Proxy) getTelegramConn(opts *mtproto.ConnectionOpts, connID string) (wrappers.Wrap, error) {
streamConn, err := p.tg.Dial(connID, opts)
func (p *Proxy) getTelegramConn(ctx context.Context, cancel context.CancelFunc,
opts *mtproto.ConnectionOpts, connID string) (wrappers.Wrap, error) {
streamConn, err := p.tg.Dial(ctx, cancel, connID, opts)
if err != nil {
return nil, errors.Annotate(err, "Cannot dial to Telegram")
}
Expand Down
7 changes: 5 additions & 2 deletions telegram/dialer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package telegram

import (
"context"
"net"
"time"

Expand Down Expand Up @@ -38,12 +39,14 @@ func (t *tgDialer) dial(addr string) (net.Conn, error) {
return conn, nil
}

func (t *tgDialer) dialRWC(addr, connID string) (wrappers.StreamReadWriteCloser, error) {
func (t *tgDialer) dialRWC(ctx context.Context, cancel context.CancelFunc,
addr, connID string) (wrappers.StreamReadWriteCloser, error) {
conn, err := t.dial(addr)
if err != nil {
return nil, err
}
tgConn := wrappers.NewConn(conn, connID, wrappers.ConnPurposeTelegram, t.conf.PublicIPv4, t.conf.PublicIPv6)
tgConn := wrappers.NewConn(ctx, cancel, conn, connID,
wrappers.ConnPurposeTelegram, t.conf.PublicIPv4, t.conf.PublicIPv6)

return tgConn, nil
}
6 changes: 4 additions & 2 deletions telegram/direct.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package telegram

import (
"context"
"net"

"github.com/juju/errors"
Expand Down Expand Up @@ -32,15 +33,16 @@ type directTelegram struct {
baseTelegram
}

func (t *directTelegram) Dial(connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) {
func (t *directTelegram) Dial(ctx context.Context, cancel context.CancelFunc,
connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) {
dc := connOpts.DC
if dc < 0 {
dc = -dc
} else if dc == 0 {
dc = 1
}

return t.baseTelegram.dial(dc-1, connID, connOpts.ConnectionProto)
return t.baseTelegram.dial(ctx, cancel, dc-1, connID, connOpts.ConnectionProto)
}

func (t *directTelegram) Init(connOpts *mtproto.ConnectionOpts,
Expand Down
5 changes: 3 additions & 2 deletions telegram/middle_caller.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package telegram

import (
"bufio"
"context"
"io/ioutil"
"net"
"net/http"
Expand Down Expand Up @@ -38,7 +39,7 @@ type middleTelegramCaller struct {
httpClient *http.Client
}

func (t *middleTelegramCaller) Dial(connID string,
func (t *middleTelegramCaller) Dial(ctx context.Context, cancel context.CancelFunc, connID string,
connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) {
dc := connOpts.DC
if dc == 0 {
Expand All @@ -47,7 +48,7 @@ func (t *middleTelegramCaller) Dial(connID string,
t.dialerMutex.RLock()
defer t.dialerMutex.RUnlock()

return t.baseTelegram.dial(dc, connID, connOpts.ConnectionProto)
return t.baseTelegram.dial(ctx, cancel, dc, connID, connOpts.ConnectionProto)
}

func (t *middleTelegramCaller) autoUpdate() {
Expand Down
7 changes: 4 additions & 3 deletions telegram/telegram.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package telegram

import (
"context"
"math/rand"

"github.com/juju/errors"
Expand All @@ -11,7 +12,7 @@ import (

// Telegram is an interface for different Telegram work modes.
type Telegram interface {
Dial(string, *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error)
Dial(context.Context, context.CancelFunc, string, *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error)
Init(*mtproto.ConnectionOpts, wrappers.StreamReadWriteCloser) (wrappers.Wrap, error)
}

Expand All @@ -22,7 +23,7 @@ type baseTelegram struct {
v6Addresses map[int16][]string
}

func (b *baseTelegram) dial(dcIdx int16, connID string,
func (b *baseTelegram) dial(ctx context.Context, cancel context.CancelFunc, dcIdx int16, connID string,
proto mtproto.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) {
addrs := make([]string, 2)

Expand All @@ -38,7 +39,7 @@ func (b *baseTelegram) dial(dcIdx int16, connID string,
}

for _, addr := range addrs {
if conn, err := b.dialer.dialRWC(addr, connID); err == nil {
if conn, err := b.dialer.dialRWC(ctx, cancel, addr, connID); err == nil {
return conn, err
}
}
Expand Down
53 changes: 39 additions & 14 deletions wrappers/conn.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package wrappers

import (
"context"
"net"
"time"

"go.uber.org/zap"

"github.com/9seconds/mtg/stats"
"github.com/juju/errors"
)

// ConnPurpose is intended to be identifier of connection purpose. We
Expand Down Expand Up @@ -39,37 +41,57 @@ const (
// Conn is a basic wrapper for net.Conn providing the most low-level
// logic and management as possible.
type Conn struct {
connID string
conn net.Conn
ctx context.Context
cancel context.CancelFunc
connID string
logger *zap.SugaredLogger

publicIPv4 net.IP
publicIPv6 net.IP
}

func (c *Conn) Write(p []byte) (int, error) {
c.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite)) // nolint: errcheck
n, err := c.conn.Write(p)
select {
case <-c.ctx.Done():
return 0, errors.Annotate(c.ctx.Err(), "Cannot write because context was closed")
default:
c.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite)) // nolint: errcheck
n, err := c.conn.Write(p)
if err != nil {
c.cancel()
}

c.logger.Debugw("Write to stream", "bytes", n, "error", err)
stats.EgressTraffic(n)
c.logger.Debugw("Write to stream", "bytes", n, "error", err)
stats.EgressTraffic(n)

return n, err
return n, err
}
}

func (c *Conn) Read(p []byte) (int, error) {
c.conn.SetReadDeadline(time.Now().Add(connTimeoutRead)) // nolint: errcheck
n, err := c.conn.Read(p)
select {
case <-c.ctx.Done():
return 0, errors.Annotate(c.ctx.Err(), "Cannot read because context was closed")
default:
c.conn.SetReadDeadline(time.Now().Add(connTimeoutRead)) // nolint: errcheck
n, err := c.conn.Read(p)
if err != nil {
c.cancel()
}

c.logger.Debugw("Read from stream", "bytes", n, "error", err)
stats.IngressTraffic(n)
c.logger.Debugw("Read from stream", "bytes", n, "error", err)
stats.IngressTraffic(n)

return n, err
return n, err
}
}

// Close closes underlying net.Conn instance.
func (c *Conn) Close() error {
defer c.logger.Debugw("Close connection")

c.cancel()
return c.conn.Close()
}

Expand Down Expand Up @@ -100,7 +122,8 @@ func (c *Conn) RemoteAddr() *net.TCPAddr {
}

// NewConn initializes Conn wrapper for net.Conn.
func NewConn(conn net.Conn, connID string, purpose ConnPurpose, publicIPv4, publicIPv6 net.IP) StreamReadWriteCloser {
func NewConn(ctx context.Context, cancel context.CancelFunc, conn net.Conn,
connID string, purpose ConnPurpose, publicIPv4, publicIPv6 net.IP) StreamReadWriteCloser {
logger := zap.S().With(
"connection_id", connID,
"local_address", conn.LocalAddr(),
Expand All @@ -109,9 +132,11 @@ func NewConn(conn net.Conn, connID string, purpose ConnPurpose, publicIPv4, publ
).Named("conn")

wrapper := Conn{
logger: logger,
connID: connID,
conn: conn,
ctx: ctx,
cancel: cancel,
connID: connID,
logger: logger,
publicIPv4: publicIPv4,
publicIPv6: publicIPv6,
}
Expand Down

0 comments on commit 243a89a

Please sign in to comment.