diff --git a/Gopkg.lock b/Gopkg.lock index 88ad0e64e..7a8f120bd 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -22,6 +22,12 @@ revision = "346938d642f2ec3594ed81d874461961cd0faa76" version = "v1.1.0" +[[projects]] + branch = "master" + name = "github.com/dustin/go-humanize" + packages = ["."] + revision = "02af3965c54e8cacf948b97fef38925c4120652c" + [[projects]] branch = "master" name = "github.com/juju/errors" @@ -80,6 +86,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "24afdd6b64331aeba47fed75918d04032e13e404612cac107bad1d68a5038b72" + inputs-digest = "c4fdd3664f683342ad0c2509f4a8bcfe5b267a6e8cdaf36f70d39536bbf89834" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index ed9feb4de..ec43ce86f 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -44,3 +44,7 @@ [[constraint]] name = "github.com/satori/go.uuid" version = "1.2.0" + +[[constraint]] + branch = "master" + name = "github.com/dustin/go-humanize" diff --git a/README.md b/README.md index 2ecccd98b..980855de9 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,18 @@ mtg is an implementation in golang which is intended to be: * **No management WebUI** This is an implementation of simple lightweight proxy. I won't do that. +This proxy supports 2 modes of work: direct connection to Telegram and +promoted channel mode. If you do not need promoted channels, I would +recommend you to go with direct mode: this is way more robust. + +To run proxy in direct mode, all you need to do is just provide a +secret. If you do not provide ADTag as a second parameter, promoted +channels mode won't be activated. + +To get promoted channel, please contact +[@MTProxybot|https://t.me/MTProxybot] and provide generated adtag as a +second parameter. + # How to build diff --git a/client/client.go b/client/client.go index de5e32a31..fd503ff6e 100644 --- a/client/client.go +++ b/client/client.go @@ -1,12 +1,12 @@ package client import ( - "io" "net" "github.com/9seconds/mtg/config" "github.com/9seconds/mtg/mtproto" + "github.com/9seconds/mtg/wrappers" ) -// Init has to initialize client connection based on given config. -type Init func(net.Conn, *config.Config) (*mtproto.ConnectionOpts, io.ReadWriteCloser, error) +// Init defines common method for initializing client connections. +type Init func(net.Conn, string, *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) diff --git a/client/direct.go b/client/direct.go index 537d12a5f..de8b66fe2 100644 --- a/client/direct.go +++ b/client/direct.go @@ -1,7 +1,6 @@ package client import ( - "io" "net" "time" @@ -15,28 +14,42 @@ import ( const ( handshakeTimeout = 10 * time.Second + readBufferSize = 64 * 1024 + writeBufferSize = 64 * 1024 ) -// DirectInit initializes client to access Telegram bypassing middleproxies. -func DirectInit(conn net.Conn, conf *config.Config) (*mtproto.ConnectionOpts, io.ReadWriteCloser, error) { - if err := config.SetSocketOptions(conn); err != nil { - return nil, nil, errors.Annotate(err, "Cannot set socket options") +// DirectInit initializes client connection for proxy which connects to +// Telegram directly. +func DirectInit(socket net.Conn, connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) { + tcpSocket := socket.(*net.TCPConn) + if err := tcpSocket.SetNoDelay(false); err != nil { + return nil, nil, errors.Annotate(err, "Cannot disable NO_DELAY to client socket") + } + if err := tcpSocket.SetReadBuffer(readBufferSize); err != nil { + return nil, nil, errors.Annotate(err, "Cannot set read buffer size of client socket") + } + if err := tcpSocket.SetWriteBuffer(writeBufferSize); err != nil { + return nil, nil, errors.Annotate(err, "Cannot set write buffer size of client socket") } - conn.SetReadDeadline(time.Now().Add(handshakeTimeout)) // nolint: errcheck - frame, err := obfuscated2.ExtractFrame(conn) - conn.SetReadDeadline(time.Time{}) // nolint: errcheck + socket.SetReadDeadline(time.Now().Add(handshakeTimeout)) // nolint: errcheck + frame, err := obfuscated2.ExtractFrame(socket) if err != nil { return nil, nil, errors.Annotate(err, "Cannot extract frame") } - defer obfuscated2.ReturnFrame(frame) + socket.SetReadDeadline(time.Time{}) // nolint: errcheck + conn := wrappers.NewConn(socket, connID, wrappers.ConnPurposeClient, conf.PublicIPv4, conf.PublicIPv6) obfs2, connOpts, err := obfuscated2.ParseObfuscated2ClientFrame(conf.Secret, frame) if err != nil { return nil, nil, errors.Annotate(err, "Cannot parse obfuscated frame") } + connOpts.ConnectionProto = mtproto.ConnectionProtocolAny + connOpts.ClientAddr = conn.RemoteAddr() + + conn = wrappers.NewStreamCipher(conn, obfs2.Encryptor, obfs2.Decryptor) - socket := wrappers.NewStreamCipherRWC(conn, obfs2.Encryptor, obfs2.Decryptor) + conn.Logger().Infow("Client connection initialized") - return connOpts, socket, nil + return conn, connOpts, nil } diff --git a/client/middle.go b/client/middle.go new file mode 100644 index 000000000..53bb64c74 --- /dev/null +++ b/client/middle.go @@ -0,0 +1,31 @@ +package client + +import ( + "net" + + "github.com/9seconds/mtg/config" + "github.com/9seconds/mtg/mtproto" + "github.com/9seconds/mtg/wrappers" +) + +// MiddleInit initializes client connection for proxy which has to +// support promoted channels, connect to Telegram middle proxies etc. +func MiddleInit(socket net.Conn, connID string, conf *config.Config) (wrappers.Wrap, *mtproto.ConnectionOpts, error) { + conn, opts, err := DirectInit(socket, connID, conf) + if err != nil { + return nil, nil, err + } + connStream := conn.(wrappers.StreamReadWriteCloser) + + newConn := wrappers.NewMTProtoAbridged(connStream, opts) + if opts.ConnectionType != mtproto.ConnectionTypeAbridged { + newConn = wrappers.NewMTProtoIntermediate(connStream, opts) + } + + opts.ConnectionProto = mtproto.ConnectionProtocolIPv4 + if socket.LocalAddr().(*net.TCPAddr).IP.To4() == nil { + opts.ConnectionProto = mtproto.ConnectionProtocolIPv6 + } + + return newConn, opts, err +} diff --git a/config/config.go b/config/config.go index 54f4cba05..f5fd97ffd 100644 --- a/config/config.go +++ b/config/config.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "strconv" - "time" "github.com/juju/errors" ) @@ -14,9 +13,6 @@ import ( const ( BufferWriteSize = 32 * 1024 BufferReadSize = 32 * 1024 - BufferSizeCopy = 32 * 1024 - - keepAlivePeriod = 20 * time.Second ) // Config represents common configuration of mtg. @@ -35,6 +31,7 @@ type Config struct { StatsIP net.IP Secret []byte + AdTag []byte } // URLs contains links to the proxy (tg://, t.me) and their QR codes. @@ -56,27 +53,28 @@ func (c *Config) BindAddr() string { return getAddr(c.BindIP, c.BindPort) } -// IPv4Addr returns connection string to ipv6 for mtproto proxy. -func (c *Config) IPv4Addr() string { - return getAddr(c.PublicIPv4, c.PublicIPv4Port) -} - -// IPv6Addr returns connection string to ipv6 for mtproto proxy. -func (c *Config) IPv6Addr() string { - return getAddr(c.PublicIPv6, c.PublicIPv6Port) -} - // StatAddr returns connection string to the stats API. func (c *Config) StatAddr() string { return getAddr(c.StatsIP, c.StatsPort) } +// UseMiddleProxy defines if this proxy has to connect middle proxies +// which supports promoted channels or directly access Telegram. +func (c *Config) UseMiddleProxy() bool { + return len(c.AdTag) > 0 +} + // GetURLs returns configured IPURLs instance with links to this server. func (c *Config) GetURLs() IPURLs { - return IPURLs{ - IPv4: getURLs(c.PublicIPv4, c.PublicIPv4Port, c.Secret), - IPv6: getURLs(c.PublicIPv6, c.PublicIPv6Port, c.Secret), + urls := IPURLs{} + if c.PublicIPv4 != nil { + urls.IPv4 = getURLs(c.PublicIPv4, c.PublicIPv4Port, c.Secret) + } + if c.PublicIPv6 != nil { + urls.IPv6 = getURLs(c.PublicIPv6, c.PublicIPv6Port, c.Secret) } + + return urls } func getAddr(host fmt.Stringer, port uint16) string { @@ -91,7 +89,7 @@ func NewConfig(debug, verbose bool, // nolint: gocyclo publicIPv4 net.IP, PublicIPv4Port uint16, publicIPv6 net.IP, publicIPv6Port uint16, statsIP net.IP, statsPort uint16, - secret string) (*Config, error) { + secret, adtag string) (*Config, error) { if len(secret) != 32 { return nil, errors.New("Telegram demands secret of length 32") } @@ -100,15 +98,22 @@ func NewConfig(debug, verbose bool, // nolint: gocyclo return nil, errors.Annotate(err, "Cannot create config") } + var adTagBytes []byte + if len(adtag) != 0 { + adTagBytes, err = hex.DecodeString(adtag) + if err != nil { + return nil, errors.Annotate(err, "Cannot create config") + } + } + if publicIPv4 == nil { publicIPv4, err = getGlobalIPv4() if err != nil { - return nil, errors.Errorf("Cannot get public IP") + publicIPv4 = nil + } else if publicIPv4.To4() == nil { + return nil, errors.Errorf("IP %s is not IPv4", publicIPv4.String()) } } - if publicIPv4.To4() == nil { - return nil, errors.Errorf("IP %s is not IPv4", publicIPv4.String()) - } if PublicIPv4Port == 0 { PublicIPv4Port = bindPort } @@ -116,12 +121,11 @@ func NewConfig(debug, verbose bool, // nolint: gocyclo if publicIPv6 == nil { publicIPv6, err = getGlobalIPv6() if err != nil { - publicIPv6 = publicIPv4 + publicIPv6 = nil + } else if publicIPv6.To4() != nil { + return nil, errors.Errorf("IP %s is not IPv6", publicIPv6.String()) } } - if publicIPv6.To16() == nil { - return nil, errors.Errorf("IP %s is not IPv6", publicIPv6.String()) - } if publicIPv6Port == 0 { publicIPv6Port = bindPort } @@ -142,30 +146,8 @@ func NewConfig(debug, verbose bool, // nolint: gocyclo StatsIP: statsIP, StatsPort: statsPort, Secret: secretBytes, + AdTag: adTagBytes, } return conf, nil } - -// SetSocketOptions makes socket keepalive, sets buffer sizes -func SetSocketOptions(conn net.Conn) error { - socket := conn.(*net.TCPConn) - - if err := socket.SetReadBuffer(BufferReadSize); err != nil { - return errors.Annotate(err, "Cannot set read buffer size") - } - if err := socket.SetWriteBuffer(BufferWriteSize); err != nil { - return errors.Annotate(err, "Cannot set write buffer size") - } - if err := socket.SetKeepAlive(true); err != nil { - return errors.Annotate(err, "Cannot make socket keepalive") - } - if err := socket.SetKeepAlivePeriod(keepAlivePeriod); err != nil { - return errors.Annotate(err, "Cannot set keepalive period") - } - if err := socket.SetNoDelay(true); err != nil { - return errors.Annotate(err, "Cannot activate nodelay for the socket") - } - - return nil -} diff --git a/main.go b/main.go index fa63cfcc2..e7d8a16ad 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "github.com/9seconds/mtg/config" "github.com/9seconds/mtg/proxy" + "github.com/9seconds/mtg/stats" "github.com/juju/errors" ) @@ -70,6 +71,7 @@ var ( Uint16() secret = app.Arg("secret", "Secret of this proxy.").Required().String() + adtag = app.Arg("adtag", "ADTag of the proxy.").String() ) func init() { @@ -91,7 +93,7 @@ func main() { *publicIPv4, *publicIPv4Port, *publicIPv6, *publicIPv6Port, *statsIP, *statsPort, - *secret, + *secret, *adtag, ) if err != nil { usage(err.Error()) @@ -110,16 +112,23 @@ func main() { zapcore.NewJSONEncoder(encoderCfg), zapcore.Lock(os.Stderr), atom, - )).Sugar() + )) + zap.ReplaceGlobals(logger) + defer logger.Sync() // nolint: errcheck - stat := proxy.NewStats(conf) - go stat.Serve() - - srv := proxy.NewServer(conf, logger, stat) printURLs(conf.GetURLs()) - if err := srv.Serve(); err != nil { - logger.Fatal(err.Error()) + if conf.UseMiddleProxy() { + zap.S().Infow("Use middle proxy connection to Telegram") + } else { + zap.S().Infow("Use direct connection to Telegram") + } + + go stats.Start(conf) + + server := proxy.NewProxy(conf) + if err := server.Serve(); err != nil { + zap.S().Fatalw("Server stopped", "error", err) } } diff --git a/mtproto/connection_options.go b/mtproto/connection_options.go index 8510692af..76d2bfebd 100644 --- a/mtproto/connection_options.go +++ b/mtproto/connection_options.go @@ -2,6 +2,7 @@ package mtproto import ( "bytes" + "net" "github.com/juju/errors" ) @@ -10,11 +11,27 @@ import ( // by the user. type ConnectionType uint8 +// ConnectionProtocol is a type of IP protocol to use. +type ConnectionProtocol uint8 + +// Hacks is a simple structure to store flags for packet transmission. +type Hacks struct { + SimpleAck bool + QuickAck bool +} + // ConnectionOpts presents an options, metadata on connection requested // by the user on handshake. type ConnectionOpts struct { - DC int16 - ConnectionType ConnectionType + DC int16 + ConnectionType ConnectionType + ConnectionProto ConnectionProtocol + // Read and Write means direction related to the client. + // ReadHacks are meant to be flushed on client read + // WriteHacks are meant to be flushed on client write. + ReadHacks Hacks + WriteHacks Hacks + ClientAddr *net.TCPAddr } // Different connection types which user requests from Telegram. @@ -24,6 +41,14 @@ const ( ConnectionTypeIntermediate ) +// ConnectionProtocol* define which connection protocols to use. +// ConnectionProtocolAny means that any is suitable. +const ( + ConnectionProtocolIPv4 ConnectionProtocol = 1 + ConnectionProtocolIPv6 = ConnectionProtocolIPv4 << 1 + ConnectionProtocolAny = ConnectionProtocolIPv4 | ConnectionProtocolIPv6 +) + // Connection tags for mtproto handshakes. var ( ConnectionTagAbridged = []byte{0xef, 0xef, 0xef, 0xef} diff --git a/mtproto/rpc/handshake_request.go b/mtproto/rpc/handshake_request.go new file mode 100644 index 000000000..acf380726 --- /dev/null +++ b/mtproto/rpc/handshake_request.go @@ -0,0 +1,26 @@ +package rpc + +import "bytes" + +// HandshakeRequest is the data type which is responsible for +// constructing of correct handshake request. +type HandshakeRequest struct { +} + +// Bytes returns serialized handshake request. +func (r *HandshakeRequest) Bytes() []byte { + buf := &bytes.Buffer{} + buf.Grow(len(TagHandshake) + len(HandshakeFlags) + len(HandshakeSenderPID) + len(HandshakePeerPID)) + + buf.Write(TagHandshake) + buf.Write(HandshakeFlags) + buf.Write(HandshakeSenderPID) + buf.Write(HandshakePeerPID) + + return buf.Bytes() +} + +// NewHandshakeRequest creates new HandshakeRequest instance. +func NewHandshakeRequest() *HandshakeRequest { + return &HandshakeRequest{} +} diff --git a/mtproto/rpc/handshake_response.go b/mtproto/rpc/handshake_response.go new file mode 100644 index 000000000..a8c522c1f --- /dev/null +++ b/mtproto/rpc/handshake_response.go @@ -0,0 +1,55 @@ +package rpc + +import ( + "bytes" + + "github.com/juju/errors" +) + +// HandshakeResponse defines data structure which is used for storage of +// handshake response. +type HandshakeResponse struct { + Type []byte + Flags []byte + SenderPID []byte + PeerPID []byte +} + +// Bytes returns a serialized handshake response. +func (r *HandshakeResponse) Bytes() []byte { + buf := &bytes.Buffer{} + + buf.Write(r.Type[:]) + buf.Write(r.Flags[:]) + buf.Write(r.SenderPID[:]) + buf.Write(r.PeerPID[:]) + + return buf.Bytes() +} + +// Valid checks that handshake response compliments request. +func (r *HandshakeResponse) Valid(req *HandshakeRequest) error { + if !bytes.Equal(r.Type, TagHandshake) { + return errors.New("Unexpected handshake tag") + } + if !bytes.Equal(r.PeerPID, HandshakeSenderPID) { + return errors.New("Incorrect sender PID") + } + + return nil +} + +// NewHandshakeResponse constructs new handshake response from the given +// data. +func NewHandshakeResponse(data []byte) (*HandshakeResponse, error) { + if len(data) != 32 { + return nil, errors.New("Incorrect handshake response length") + } + + return &HandshakeResponse{ + Type: data[:4], + Flags: data[4:8], + SenderPID: data[8:20], + PeerPID: data[20:], + }, nil +} diff --git a/mtproto/rpc/nonce_request.go b/mtproto/rpc/nonce_request.go new file mode 100644 index 000000000..b714a9c78 --- /dev/null +++ b/mtproto/rpc/nonce_request.go @@ -0,0 +1,52 @@ +package rpc + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "time" + + "github.com/juju/errors" +) + +// NonceRequest is the data type which contains all the data for correct +// nonce request. +type NonceRequest struct { + KeySelector []byte + CryptoTS []byte + Nonce []byte +} + +// Bytes returns serialized nonce request. +func (r *NonceRequest) Bytes() []byte { + buf := &bytes.Buffer{} + + buf.Write(TagNonce) + buf.Write(r.KeySelector) + buf.Write(NonceCryptoAES) + buf.Write(r.CryptoTS) + buf.Write(r.Nonce) + + return buf.Bytes() +} + +// NewNonceRequest builds new none request based on proxy secret. +func NewNonceRequest(proxySecret []byte) (*NonceRequest, error) { + nonce := make([]byte, 16) + keySelector := make([]byte, 4) + cryptoTS := make([]byte, 4) + + if _, err := rand.Read(nonce); err != nil { + return nil, errors.Annotate(err, "Cannot generate nonce") + } + copy(keySelector, proxySecret) + + timestamp := time.Now().Truncate(time.Second).Unix() % 4294967296 // 256 ^ 4 - do not know how to name + binary.LittleEndian.PutUint32(cryptoTS, uint32(timestamp)) + + return &NonceRequest{ + KeySelector: keySelector, + CryptoTS: cryptoTS, + Nonce: nonce, + }, nil +} diff --git a/mtproto/rpc/nonce_response.go b/mtproto/rpc/nonce_response.go new file mode 100644 index 000000000..f76bd44f7 --- /dev/null +++ b/mtproto/rpc/nonce_response.go @@ -0,0 +1,60 @@ +package rpc + +import ( + "bytes" + + "github.com/juju/errors" +) + +// NonceResponse is the data type which contains data of nonce response. +type NonceResponse struct { + NonceRequest + + Type []byte + Crypto []byte +} + +// Bytes returns serialized form of the nonce response. +func (r *NonceResponse) Bytes() []byte { + buf := &bytes.Buffer{} + + buf.Write(r.Type) + buf.Write(r.KeySelector) + buf.Write(r.Crypto) + buf.Write(r.CryptoTS) + buf.Write(r.Nonce) + + return buf.Bytes() +} + +// Valid checks that nonce response compliments nonce request. +func (r *NonceResponse) Valid(req *NonceRequest) error { + if !bytes.Equal(r.Type, TagNonce) { + return errors.New("Unexpected RPC type") + } + if !bytes.Equal(r.Crypto, NonceCryptoAES) { + return errors.New("Unexpected crypto type") + } + if !bytes.Equal(r.KeySelector, req.KeySelector) { + return errors.New("Unexpected key selector") + } + + return nil +} + +// NewNonceResponse build new nonce response based on the given data. +func NewNonceResponse(data []byte) (*NonceResponse, error) { + if len(data) != 32 { + return nil, errors.New("Unexpected message length") + } + + return &NonceResponse{ + NonceRequest: NonceRequest{ + KeySelector: data[4:8], + CryptoTS: data[12:16], + Nonce: data[16:], + }, + Type: data[:4], + Crypto: data[8:12], + }, nil +} diff --git a/mtproto/rpc/proxy_flags.go b/mtproto/rpc/proxy_flags.go new file mode 100644 index 000000000..559953717 --- /dev/null +++ b/mtproto/rpc/proxy_flags.go @@ -0,0 +1,55 @@ +package rpc + +import ( + "encoding/binary" + "strings" +) + +type proxyRequestFlags uint32 + +const ( + proxyRequestFlagsHasAdTag proxyRequestFlags = 0x8 + proxyRequestFlagsEncrypted = 0x2 + proxyRequestFlagsMagic = 0x1000 + proxyRequestFlagsExtMode2 = 0x20000 + proxyRequestFlagsIntermediate = 0x20000000 + proxyRequestFlagsAbdridged = 0x40000000 + proxyRequestFlagsQuickAck = 0x80000000 +) + +var proxyRequestFlagsEncryptedPrefix [8]byte + +func (r proxyRequestFlags) Bytes() []byte { + converted := make([]byte, 4) + binary.LittleEndian.PutUint32(converted, uint32(r)) + + return converted +} + +func (r proxyRequestFlags) String() string { + flags := make([]string, 0, 7) + + if r&proxyRequestFlagsHasAdTag != 0 { + flags = append(flags, "HAS_AD_TAG") + } + if r&proxyRequestFlagsEncrypted != 0 { + flags = append(flags, "ENCRYPTED") + } + if r&proxyRequestFlagsMagic != 0 { + flags = append(flags, "MAGIC") + } + if r&proxyRequestFlagsExtMode2 != 0 { + flags = append(flags, "EXT_MODE_2") + } + if r&proxyRequestFlagsIntermediate != 0 { + flags = append(flags, "INTERMEDIATE") + } + if r&proxyRequestFlagsAbdridged != 0 { + flags = append(flags, "ABRIDGED") + } + if r&proxyRequestFlagsQuickAck != 0 { + flags = append(flags, "QUICK_ACK") + } + + return strings.Join(flags, " | ") +} diff --git a/mtproto/rpc/proxy_request.go b/mtproto/rpc/proxy_request.go new file mode 100644 index 000000000..9e2dacf7b --- /dev/null +++ b/mtproto/rpc/proxy_request.go @@ -0,0 +1,100 @@ +package rpc + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "fmt" + "net" + + "github.com/juju/errors" + + "github.com/9seconds/mtg/mtproto" +) + +// ProxyRequest is the data type for storing data required to compose +// RPC_PROXY_REQ request. +type ProxyRequest struct { + Flags proxyRequestFlags + ConnectionID []byte + OurIPPort []byte + ClientIPPort []byte + ADTag []byte + Options *mtproto.ConnectionOpts +} + +// MakeHeader makes RPC_PROXY_REQ header. We need only to append the +// data for it. +func (r *ProxyRequest) MakeHeader(message []byte) (*bytes.Buffer, fmt.Stringer) { + bufferLength := len(TagProxyRequest) + + 4 + // len(flags) + len(r.ConnectionID) + + len(r.ClientIPPort) + + len(r.OurIPPort) + + len(ProxyRequestExtraSize) + + len(ProxyRequestProxyTag) + + 1 + // len(AdTag) + len(r.ADTag) + bufferLength += bufferLength % 4 + + buf := &bytes.Buffer{} + buf.Grow(bufferLength + len(message)) + + flags := r.Flags + if r.Options.ReadHacks.QuickAck { + flags |= proxyRequestFlagsQuickAck + } + + if bytes.HasPrefix(message, proxyRequestFlagsEncryptedPrefix[:]) { + flags |= proxyRequestFlagsEncrypted + } + + buf.Write(TagProxyRequest) + buf.Write(flags.Bytes()) + buf.Write(r.ConnectionID) + buf.Write(r.ClientIPPort) + buf.Write(r.OurIPPort) + buf.Write(ProxyRequestExtraSize) + buf.Write(ProxyRequestProxyTag) + buf.WriteByte(byte(len(r.ADTag))) + buf.Write(r.ADTag) + buf.Write(make([]byte, (4-buf.Len()%4)%4)) + + return buf, flags +} + +// NewProxyRequest build new ProxyRequest data structure. +func NewProxyRequest(clientAddr, ownAddr *net.TCPAddr, opts *mtproto.ConnectionOpts, adTag []byte) (*ProxyRequest, error) { + flags := proxyRequestFlagsHasAdTag | proxyRequestFlagsMagic | proxyRequestFlagsExtMode2 + + switch opts.ConnectionType { + case mtproto.ConnectionTypeAbridged: + flags |= proxyRequestFlagsAbdridged + case mtproto.ConnectionTypeIntermediate: + flags |= proxyRequestFlagsIntermediate + } + + request := &ProxyRequest{ + Flags: flags, + ADTag: adTag, + Options: opts, + ConnectionID: make([]byte, 8), + ClientIPPort: make([]byte, 16+4), + OurIPPort: make([]byte, 16+4), + } + + if _, err := rand.Read(request.ConnectionID); err != nil { + return nil, errors.Annotate(err, "Cannot generate connection ID") + } + + port := [4]byte{} + copy(request.ClientIPPort[:16], clientAddr.IP.To16()) + binary.LittleEndian.PutUint32(port[:], uint32(clientAddr.Port)) + copy(request.ClientIPPort[16:], port[:]) + + copy(request.OurIPPort[:16], ownAddr.IP.To16()) + binary.LittleEndian.PutUint32(port[:], uint32(ownAddr.Port)) + copy(request.OurIPPort[16:], port[:]) + + return request, nil +} diff --git a/mtproto/rpc/rpc.go b/mtproto/rpc/rpc.go new file mode 100644 index 000000000..5d242f568 --- /dev/null +++ b/mtproto/rpc/rpc.go @@ -0,0 +1,33 @@ +package rpc + +// SeqNo* is the number of the sequence which have special meaning for +// the Telegram. +const ( + SeqNoNonce = -2 + SeqNoHandshake = -1 +) + +// Different constants for RPC protocol +var ( + TagCloseExt = []byte{0xa2, 0x34, 0xb6, 0x5e} + TagProxyAns = []byte{0x0d, 0xda, 0x03, 0x44} + TagSimpleAck = []byte{0x9b, 0x40, 0xac, 0x3b} + TagHandshake = []byte{0xf5, 0xee, 0x82, 0x76} + TagNonce = []byte{0xaa, 0x87, 0xcb, 0x7a} + TagProxyRequest = []byte{0xee, 0xf1, 0xce, 0x36} + + NonceCryptoAES = []byte{0x01, 0x00, 0x00, 0x00} + + HandshakeFlags = []byte{0x00, 0x00, 0x00, 0x00} + + ProxyRequestExtraSize = []byte{0x18, 0x00, 0x00, 0x00} + ProxyRequestProxyTag = []byte{0xae, 0x26, 0x1e, 0xdb} + + HandshakeSenderPID []byte + HandshakePeerPID []byte +) + +func init() { + HandshakeSenderPID = []byte("IPIPPRPDTIME") + HandshakePeerPID = []byte("IPIPPRPDTIME") +} diff --git a/obfuscated2/frame.go b/obfuscated2/frame.go index ea03241ca..6c6f6abc1 100644 --- a/obfuscated2/frame.go +++ b/obfuscated2/frame.go @@ -66,57 +66,55 @@ func (f Frame) ConnectionType() (mtproto.ConnectionType, error) { // Invert inverts frame for extracting encryption keys. Pkease check that link: // https://blog.susanka.eu/how-telegram-obfuscates-its-mtproto-traffic/ -func (f Frame) Invert() *Frame { - reversed := MakeFrame() - copy(*reversed, f) +func (f Frame) Invert() Frame { + reversed := make(Frame, FrameLen) + copy(reversed, f) for i := 0; i < frameLenKey+frameLenIV; i++ { - (*reversed)[frameOffsetFirst+i] = f[frameOffsetIV-1-i] + reversed[frameOffsetFirst+i] = f[frameOffsetIV-1-i] } return reversed } // ExtractFrame extracts exact obfuscated2 handshake frame from given reader. -func ExtractFrame(conn io.Reader) (*Frame, error) { - frame := MakeFrame() - buf := bytes.NewBuffer(*frame) +func ExtractFrame(conn io.Reader) (Frame, error) { + frame := make(Frame, FrameLen) + buf := bytes.NewBuffer(frame) buf.Reset() if _, err := io.CopyN(buf, conn, FrameLen); err != nil { - ReturnFrame(frame) return nil, errors.Annotate(err, "Cannot extract obfuscated header") } - copy(*frame, buf.Bytes()) + copy(frame, buf.Bytes()) return frame, nil } -func generateFrame(connectionType mtproto.ConnectionType) *Frame { - frame := MakeFrame() - data := *frame +func generateFrame(connectionType mtproto.ConnectionType) Frame { + frame := make(Frame, FrameLen) for { - if _, err := rand.Read(data); err != nil { + if _, err := rand.Read(frame); err != nil { continue } - if data[0] == 0xef { + if frame[0] == 0xef { continue } - val := (uint32(data[3]) << 24) | (uint32(data[2]) << 16) | (uint32(data[1]) << 8) | uint32(data[0]) + val := (uint32(frame[3]) << 24) | (uint32(frame[2]) << 16) | (uint32(frame[1]) << 8) | uint32(frame[0]) if val == 0x44414548 || val == 0x54534f50 || val == 0x20544547 || val == 0x4954504f || val == 0xeeeeeeee { continue } - val = (uint32(data[7]) << 24) | (uint32(data[6]) << 16) | (uint32(data[5]) << 8) | uint32(data[4]) + val = (uint32(frame[7]) << 24) | (uint32(frame[6]) << 16) | (uint32(frame[5]) << 8) | uint32(frame[4]) if val == 0x00000000 { continue } // error has to be checked before calling this function tag, _ := connectionType.Tag() // nolint: errcheck - copy(data.Magic(), tag) + copy(frame.Magic(), tag) return frame } diff --git a/obfuscated2/frame_pool.go b/obfuscated2/frame_pool.go deleted file mode 100644 index 431073fe2..000000000 --- a/obfuscated2/frame_pool.go +++ /dev/null @@ -1,24 +0,0 @@ -package obfuscated2 - -import "sync" - -var framePool sync.Pool - -// MakeFrame returns new pointer to the handshake frame. -func MakeFrame() *Frame { - return framePool.Get().(*Frame) -} - -// ReturnFrame returns pointer to the handshake frame back to the pool. -func ReturnFrame(f *Frame) { - framePool.Put(f) -} - -func init() { - framePool = sync.Pool{ - New: func() interface{} { - data := make(Frame, FrameLen) - return &data - }, - } -} diff --git a/obfuscated2/frame_test.go b/obfuscated2/frame_test.go index 0c7e69208..42d997c77 100644 --- a/obfuscated2/frame_test.go +++ b/obfuscated2/frame_test.go @@ -54,21 +54,21 @@ func TestFrameValid(t *testing.T) { func TestFrameDoubleInvert(t *testing.T) { frame := makeFrame() - assert.True(t, bytes.Equal(frame, *frame.Invert().Invert())) + assert.True(t, bytes.Equal(frame, frame.Invert().Invert())) } func TestFrameInvert(t *testing.T) { frame := makeFrame() reversed := frame.Invert() - assert.Exactly(t, frame[:8], (*reversed)[:8]) - assert.Exactly(t, frame[56:], (*reversed)[56:]) + assert.Exactly(t, frame[:8], reversed[:8]) + assert.Exactly(t, frame[56:], reversed[56:]) toCompare := make([]byte, 48) for i := 0; i < 48; i++ { toCompare[i] = frame[55-i] } - assert.Equal(t, []byte((*reversed)[8:56]), toCompare) + assert.Equal(t, []byte(reversed[8:56]), toCompare) } func TestFrameGenerateValid(t *testing.T) { diff --git a/obfuscated2/obfuscated2.go b/obfuscated2/obfuscated2.go index bc027adb1..50a34c794 100644 --- a/obfuscated2/obfuscated2.go +++ b/obfuscated2/obfuscated2.go @@ -21,7 +21,7 @@ type Obfuscated2 struct { // details: http://telegra.ph/telegram-blocks-wtf-05-26 // // Beware, link above is in russian. -func ParseObfuscated2ClientFrame(secret []byte, frame *Frame) (*Obfuscated2, *mtproto.ConnectionOpts, error) { +func ParseObfuscated2ClientFrame(secret []byte, frame Frame) (*Obfuscated2, *mtproto.ConnectionOpts, error) { decHasher := sha256.New() decHasher.Write(frame.Key()) // nolint: errcheck decHasher.Write(secret) // nolint: errcheck @@ -33,9 +33,8 @@ func ParseObfuscated2ClientFrame(secret []byte, frame *Frame) (*Obfuscated2, *mt encHasher.Write(secret) // nolint: errcheck encryptor := makeStreamCipher(encHasher.Sum(nil), invertedFrame.IV()) - decryptedFrame := MakeFrame() - defer ReturnFrame(decryptedFrame) - decryptor.XORKeyStream(*decryptedFrame, *frame) + decryptedFrame := make(Frame, FrameLen) + decryptor.XORKeyStream(decryptedFrame, frame) connType, err := decryptedFrame.ConnectionType() if err != nil { return nil, nil, errors.Annotate(err, "Unknown protocol") @@ -56,18 +55,17 @@ func ParseObfuscated2ClientFrame(secret []byte, frame *Frame) (*Obfuscated2, *mt // MakeTelegramObfuscated2Frame creates new handshake frame to send to // Telegram. // https://blog.susanka.eu/how-telegram-obfuscates-its-mtproto-traffic/ -func MakeTelegramObfuscated2Frame(opts *mtproto.ConnectionOpts) (*Obfuscated2, *Frame) { +func MakeTelegramObfuscated2Frame(opts *mtproto.ConnectionOpts) (*Obfuscated2, Frame) { frame := generateFrame(opts.ConnectionType) encryptor := makeStreamCipher(frame.Key(), frame.IV()) decryptorFrame := frame.Invert() decryptor := makeStreamCipher(decryptorFrame.Key(), decryptorFrame.IV()) - copyFrame := MakeFrame() - defer ReturnFrame(copyFrame) - copy((*copyFrame)[:frameOffsetIV], (*frame)[:frameOffsetIV]) - encryptor.XORKeyStream(*frame, *frame) - copy((*frame)[:frameOffsetIV], (*copyFrame)[:frameOffsetIV]) + copyFrame := make(Frame, FrameLen) + copy(copyFrame[:frameOffsetIV], frame[:frameOffsetIV]) + encryptor.XORKeyStream(frame, frame) + copy(frame[:frameOffsetIV], copyFrame[:frameOffsetIV]) obfs := &Obfuscated2{ Decryptor: decryptor, diff --git a/obfuscated2/obfuscated2_test.go b/obfuscated2/obfuscated2_test.go index 0c7b1a8aa..c6b5c4ab8 100644 --- a/obfuscated2/obfuscated2_test.go +++ b/obfuscated2/obfuscated2_test.go @@ -18,7 +18,7 @@ func TestObfs2TelegramFrameDecrypt(t *testing.T) { decryptor := makeStreamCipher(frame.Key(), frame.IV()) decrypted := make(Frame, FrameLen) - decryptor.XORKeyStream(decrypted, *frame) + decryptor.XORKeyStream(decrypted, frame) _, err := decrypted.ConnectionType() assert.Nil(t, err) @@ -53,8 +53,8 @@ func TestObfs2Full(t *testing.T) { encryptor := makeStreamCipher(clientKey, clientFrame.IV()) encrypted := make(Frame, FrameLen) - encryptor.XORKeyStream(encrypted, *clientFrame) - copy(encrypted[:56], (*clientFrame)[:56]) + encryptor.XORKeyStream(encrypted, clientFrame) + copy(encrypted[:56], clientFrame[:56]) invertedClientFrame := clientFrame.Invert() clientHasher = sha256.New() @@ -63,7 +63,7 @@ func TestObfs2Full(t *testing.T) { invertedClientKey := clientHasher.Sum(nil) clientDecryptor := makeStreamCipher(invertedClientKey, invertedClientFrame.IV()) - clientObfs, _, err := ParseObfuscated2ClientFrame(secret, &encrypted) + clientObfs, _, err := ParseObfuscated2ClientFrame(secret, encrypted) assert.Nil(t, err) connOpts := &mtproto.ConnectionOpts{ @@ -73,7 +73,7 @@ func TestObfs2Full(t *testing.T) { tgObfs, tgFrame := MakeTelegramObfuscated2Frame(connOpts) tgDecryptor := makeStreamCipher(tgFrame.Key(), tgFrame.IV()) decrypted := make(Frame, FrameLen) - tgDecryptor.XORKeyStream(decrypted, *tgFrame) + tgDecryptor.XORKeyStream(decrypted, tgFrame) _, err = decrypted.ConnectionType() assert.Nil(t, err) diff --git a/proxy/copy_pool.go b/proxy/copy_pool.go deleted file mode 100644 index 23477aa85..000000000 --- a/proxy/copy_pool.go +++ /dev/null @@ -1,18 +0,0 @@ -package proxy - -import ( - "sync" - - "github.com/9seconds/mtg/config" -) - -var copyPool sync.Pool - -func init() { - copyPool = sync.Pool{ - New: func() interface{} { - data := make([]byte, config.BufferSizeCopy) - return &data - }, - } -} diff --git a/proxy/proxy.go b/proxy/proxy.go new file mode 100644 index 000000000..3fba774f8 --- /dev/null +++ b/proxy/proxy.go @@ -0,0 +1,162 @@ +package proxy + +import ( + "io" + "net" + "sync" + + "github.com/juju/errors" + uuid "github.com/satori/go.uuid" + "go.uber.org/zap" + + "github.com/9seconds/mtg/client" + "github.com/9seconds/mtg/config" + "github.com/9seconds/mtg/mtproto" + "github.com/9seconds/mtg/stats" + "github.com/9seconds/mtg/telegram" + "github.com/9seconds/mtg/wrappers" +) + +// Proxy is a core of this program. +type Proxy struct { + clientInit client.Init + tg telegram.Telegram + conf *config.Config +} + +// Serve runs TCP proxy server. +func (p *Proxy) Serve() error { + lsock, err := net.Listen("tcp", p.conf.BindAddr()) + if err != nil { + return errors.Annotate(err, "Cannot create listen socket") + } + + for { + if conn, err := lsock.Accept(); err != nil { + zap.S().Errorw("Cannot allocate incoming connection", "error", err) + } else { + go p.accept(conn) + } + } +} + +func (p *Proxy) accept(conn net.Conn) { + connID := uuid.NewV4().String() + log := zap.S().With("connection_id", connID).Named("main") + + defer func() { + conn.Close() // nolint: errcheck + + if err := recover(); err != nil { + stats.NewCrash() + log.Errorw("Crash of accept handler", "error", err) + } + }() + + log.Infow("Client connected", "addr", conn.RemoteAddr()) + + clientConn, opts, err := p.clientInit(conn, connID, p.conf) + if err != nil { + log.Errorw("Cannot initialize client connection", "error", err) + return + } + defer clientConn.(io.Closer).Close() // nolint: errcheck + + stats.ClientConnected(opts.ConnectionType, clientConn.RemoteAddr()) + defer stats.ClientDisconnected(opts.ConnectionType, clientConn.RemoteAddr()) + + serverConn, err := p.getTelegramConn(opts, connID) + if err != nil { + log.Errorw("Cannot initialize server connection", "error", err) + return + } + defer serverConn.(io.Closer).Close() // nolint: errcheck + + wait := &sync.WaitGroup{} + wait.Add(2) + + if p.conf.UseMiddleProxy() { + clientPacket := clientConn.(wrappers.PacketReadWriteCloser) + serverPacket := serverConn.(wrappers.PacketReadWriteCloser) + go p.middlePipe(clientPacket, serverPacket, wait, &opts.ReadHacks) + go p.middlePipe(serverPacket, clientPacket, wait, &opts.WriteHacks) + } else { + clientStream := clientConn.(wrappers.StreamReadWriteCloser) + serverStream := serverConn.(wrappers.StreamReadWriteCloser) + go p.directPipe(clientStream, serverStream, wait) + go p.directPipe(serverStream, clientStream, wait) + } + + wait.Wait() + + log.Infow("Client disconnected", "addr", conn.RemoteAddr()) +} + +func (p *Proxy) getTelegramConn(opts *mtproto.ConnectionOpts, connID string) (wrappers.Wrap, error) { + streamConn, err := p.tg.Dial(connID, opts) + if err != nil { + return nil, errors.Annotate(err, "Cannot dial to Telegram") + } + + packetConn, err := p.tg.Init(opts, streamConn) + if err != nil { + return nil, errors.Annotate(err, "Cannot handshake telegram") + } + + return packetConn, nil +} + +func (p *Proxy) middlePipe(src wrappers.PacketReadCloser, dst io.WriteCloser, wait *sync.WaitGroup, hacks *mtproto.Hacks) { + defer func() { + src.Close() // nolint: errcheck + dst.Close() // nolint: errcheck + wait.Done() + }() + + for { + hacks.SimpleAck = false + hacks.QuickAck = false + + packet, err := src.Read() + if err != nil { + src.Logger().Warnw("Cannot read packet", "error", err) + return + } + if _, err = dst.Write(packet); err != nil { + src.Logger().Warnw("Cannot write packet", "error", err) + return + } + } +} + +func (p *Proxy) directPipe(src wrappers.StreamReadCloser, dst io.WriteCloser, wait *sync.WaitGroup) { + defer func() { + src.Close() // nolint: errcheck + dst.Close() // nolint: errcheck + wait.Done() + }() + + if _, err := io.Copy(dst, src); err != nil { + src.Logger().Warnw("Cannot pump sockets", "error", err) + } +} + +// NewProxy returns new proxy instance. +func NewProxy(conf *config.Config) *Proxy { + var clientInit client.Init + var tg telegram.Telegram + + if conf.UseMiddleProxy() { + clientInit = client.MiddleInit + tg = telegram.NewMiddleTelegram(conf) + } else { + clientInit = client.DirectInit + tg = telegram.NewDirectTelegram(conf) + } + + return &Proxy{ + conf: conf, + clientInit: clientInit, + tg: tg, + } +} diff --git a/proxy/server.go b/proxy/server.go deleted file mode 100644 index 1a6ea73bc..000000000 --- a/proxy/server.go +++ /dev/null @@ -1,149 +0,0 @@ -package proxy - -import ( - "context" - "io" - "net" - "sync" - - "github.com/juju/errors" - uuid "github.com/satori/go.uuid" - "go.uber.org/zap" - - "github.com/9seconds/mtg/client" - "github.com/9seconds/mtg/config" - "github.com/9seconds/mtg/mtproto" - "github.com/9seconds/mtg/telegram" - "github.com/9seconds/mtg/wrappers" -) - -// Server is an insgtance of MTPROTO proxy. -type Server struct { - conf *config.Config - logger *zap.SugaredLogger - stats *Stats - tg telegram.Telegram - clientInit client.Init -} - -// Serve does MTPROTO proxying. -func (s *Server) Serve() error { - lsock, err := net.Listen("tcp", s.conf.BindAddr()) - if err != nil { - return errors.Annotate(err, "Cannot create listen socket") - } - - for { - if conn, err := lsock.Accept(); err != nil { - s.logger.Warn("Cannot allocate incoming connection", "error", err) - } else { - go s.accept(conn) - } - } -} - -func (s *Server) accept(conn net.Conn) { - defer func() { - s.stats.closeConnection() - conn.Close() // nolint: errcheck - - if r := recover(); r != nil { - s.logger.Errorw("Crash of accept handler", "error", r) - } - }() - - s.stats.newConnection() - ctx, cancel := context.WithCancel(context.Background()) - socketID := uuid.NewV4().String() - - s.logger.Debugw("Client connected", - "addr", conn.RemoteAddr().String(), - "socketid", socketID, - ) - - connOpts, clientConn, err := s.getClientStream(ctx, cancel, conn, socketID) - if err != nil { - s.logger.Warnw("Cannot initialize client connection", - "addr", conn.RemoteAddr().String(), - "socketid", socketID, - "error", err, - ) - return - } - defer clientConn.Close() // nolint: errcheck - - tgConn, err := s.getTelegramStream(ctx, cancel, connOpts, socketID) - if err != nil { - s.logger.Warnw("Cannot initialize Telegram connection", - "socketid", socketID, - "error", err, - ) - return - } - defer tgConn.Close() // nolint: errcheck - - wait := &sync.WaitGroup{} - wait.Add(2) - - go s.pipe(clientConn, tgConn, wait) - go s.pipe(tgConn, clientConn, wait) - - <-ctx.Done() - wait.Wait() - - s.logger.Debugw("Client disconnected", - "addr", conn.RemoteAddr().String(), - "socketid", socketID, - ) -} - -func (s *Server) getClientStream(ctx context.Context, cancel context.CancelFunc, conn net.Conn, socketID string) (*mtproto.ConnectionOpts, io.ReadWriteCloser, error) { - connOpts, socket, err := s.clientInit(conn, s.conf) - if err != nil { - return nil, nil, errors.Annotate(err, "Cannot init client connection") - } - - socket = wrappers.NewTrafficRWC(socket, s.stats.addIncomingTraffic, s.stats.addOutgoingTraffic) - socket = wrappers.NewLogRWC(socket, s.logger, socketID, "client") - socket = wrappers.NewCtxRWC(ctx, cancel, socket) - - return connOpts, socket, nil -} - -func (s *Server) getTelegramStream(ctx context.Context, cancel context.CancelFunc, connOpts *mtproto.ConnectionOpts, socketID string) (io.ReadWriteCloser, error) { - conn, err := s.tg.Dial(connOpts) - if err != nil { - return nil, errors.Annotate(err, "Cannot connect to Telegram") - } - - conn = wrappers.NewTrafficRWC(conn, s.stats.addIncomingTraffic, s.stats.addOutgoingTraffic) - conn, err = s.tg.Init(connOpts, conn) - if err != nil { - return nil, errors.Annotate(err, "Cannot handshake Telegram") - } - - conn = wrappers.NewLogRWC(conn, s.logger, socketID, "telegram") - conn = wrappers.NewCtxRWC(ctx, cancel, conn) - - return conn, nil -} - -func (s *Server) pipe(dst io.Writer, src io.Reader, wait *sync.WaitGroup) { - defer wait.Done() - - buf := copyPool.Get().(*[]byte) - defer copyPool.Put(buf) - - io.CopyBuffer(dst, src, *buf) // nolint: errcheck -} - -// NewServer creates new instance of MTPROTO proxy. -func NewServer(conf *config.Config, logger *zap.SugaredLogger, stat *Stats) *Server { - return &Server{ - conf: conf, - logger: logger, - stats: stat, - tg: telegram.NewDirectTelegram(conf), - clientInit: client.DirectInit, - } -} diff --git a/proxy/stats.go b/proxy/stats.go deleted file mode 100644 index 9469c2f7d..000000000 --- a/proxy/stats.go +++ /dev/null @@ -1,74 +0,0 @@ -package proxy - -import ( - "encoding/json" - "net/http" - "strconv" - "sync/atomic" - "time" - - "github.com/9seconds/mtg/config" -) - -type statsUptime time.Time - -func (s statsUptime) MarshalJSON() ([]byte, error) { - uptime := int(time.Since(time.Time(s)).Seconds()) - return []byte(strconv.Itoa(uptime)), nil -} - -// Stats is a datastructure for statistics on work of this proxy. -type Stats struct { - AllConnections uint64 `json:"all_connections"` - ActiveConnections uint32 `json:"active_connections"` - Traffic struct { - Incoming uint64 `json:"incoming"` - Outgoing uint64 `json:"outgoing"` - } `json:"traffic"` - URLs config.IPURLs `json:"urls"` - Uptime statsUptime `json:"uptime"` - - conf *config.Config -} - -func (s *Stats) newConnection() { - atomic.AddUint64(&s.AllConnections, 1) - atomic.AddUint32(&s.ActiveConnections, 1) -} - -func (s *Stats) closeConnection() { - atomic.AddUint32(&s.ActiveConnections, ^uint32(0)) -} - -func (s *Stats) addIncomingTraffic(n int) { - atomic.AddUint64(&s.Traffic.Incoming, uint64(n)) -} - -func (s *Stats) addOutgoingTraffic(n int) { - atomic.AddUint64(&s.Traffic.Outgoing, uint64(n)) -} - -// Serve runs statistics HTTP server. -func (s *Stats) Serve() { - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - encoder := json.NewEncoder(w) - encoder.SetEscapeHTML(false) - encoder.SetIndent("", " ") - encoder.Encode(s) // nolint: errcheck, gas - }) - - http.ListenAndServe(s.conf.StatAddr(), nil) // nolint: errcheck, gas -} - -// NewStats returns new instance of statistics datastructure. -func NewStats(conf *config.Config) *Stats { - stat := &Stats{ - Uptime: statsUptime(time.Now()), - conf: conf, - } - stat.URLs = conf.GetURLs() - - return stat -} diff --git a/run-mtg.sh b/run-mtg.sh index 96d72d340..60459c0e2 100755 --- a/run-mtg.sh +++ b/run-mtg.sh @@ -21,7 +21,6 @@ docker run \ --sysctl net.ipv4.tcp_congestion_control=bbr \ --sysctl net.ipv4.tcp_fastopen=3 \ --sysctl net.ipv4.tcp_fin_timeout=30 \ - --sysctl net.ipv4.tcp_keepalive_time=1200 \ --sysctl net.ipv4.tcp_max_syn_backlog=4096 \ --sysctl net.ipv4.tcp_max_tw_buckets=5000 \ --sysctl net.ipv4.tcp_mtu_probing=1 \ diff --git a/stats/channels.go b/stats/channels.go new file mode 100644 index 000000000..ee2911bac --- /dev/null +++ b/stats/channels.go @@ -0,0 +1,151 @@ +package stats + +import ( + "net" + "time" + + "github.com/9seconds/mtg/mtproto" +) + +const ( + crashesChanLength = 1 + connectionsChanLength = 20 + trafficChanLength = 5000 +) + +var ( + crashesChan = make(chan struct{}, crashesChanLength) + connectionsChan = make(chan *connectionData, connectionsChanLength) + trafficChan = make(chan *trafficData, trafficChanLength) +) + +type connectionData struct { + connectionType mtproto.ConnectionType + connected bool + addr *net.TCPAddr +} + +type trafficData struct { + traffic int + ingress bool +} + +func crashManager() { + for range crashesChan { + instance.mutex.RLock() + + instance.Crashes++ + + instance.mutex.RUnlock() + } +} + +func connectionManager() { // nolint: gocyclo + for event := range connectionsChan { + instance.mutex.RLock() + + isIPv4 := event.addr.IP.To4() != nil + var inc uint32 = 1 + if !event.connected { + inc = ^uint32(0) + } + + switch event.connectionType { + case mtproto.ConnectionTypeAbridged: + if isIPv4 { + instance.ActiveConnections.Abridged.IPv4 += inc + if event.connected { + instance.AllConnections.Abridged.IPv4 += inc + } + } else { + instance.ActiveConnections.Abridged.IPv6 += inc + if event.connected { + instance.AllConnections.Abridged.IPv6 += inc + } + } + default: + if isIPv4 { + instance.ActiveConnections.Intermediate.IPv4 += inc + if event.connected { + instance.AllConnections.Intermediate.IPv4 += inc + } + } else { + instance.ActiveConnections.Intermediate.IPv6 += inc + if event.connected { + instance.AllConnections.Intermediate.IPv6 += inc + } + } + } + + instance.mutex.RUnlock() + } +} + +func trafficManager() { + speedChan := time.Tick(time.Second) + + for { + select { + case event := <-trafficChan: + instance.mutex.RLock() + + if event.ingress { + instance.Traffic.Ingress += trafficValue(event.traffic) + instance.speedCurrent.Ingress += trafficSpeedValue(event.traffic) + } else { + instance.Traffic.Egress += trafficValue(event.traffic) + instance.speedCurrent.Egress += trafficSpeedValue(event.traffic) + } + + instance.mutex.RUnlock() + case <-speedChan: + instance.mutex.RLock() + + instance.Speed.Ingress = instance.speedCurrent.Ingress + instance.Speed.Egress = instance.speedCurrent.Egress + instance.speedCurrent.Ingress = trafficSpeedValue(0) + instance.speedCurrent.Egress = trafficSpeedValue(0) + + instance.mutex.RUnlock() + } + } +} + +// NewCrash indicates new crash. +func NewCrash() { + crashesChan <- struct{}{} +} + +// ClientConnected indicates that new client was connected. +func ClientConnected(connectionType mtproto.ConnectionType, addr *net.TCPAddr) { + connectionsChan <- &connectionData{ + connectionType: connectionType, + addr: addr, + connected: true, + } +} + +// ClientDisconnected indicates that client was disconnected. +func ClientDisconnected(connectionType mtproto.ConnectionType, addr *net.TCPAddr) { + connectionsChan <- &connectionData{ + connectionType: connectionType, + addr: addr, + connected: false, + } +} + +// IngressTraffic accounts new ingress traffic. +func IngressTraffic(traffic int) { + trafficChan <- &trafficData{ + traffic: traffic, + ingress: true, + } +} + +// EgressTraffic accounts new ingress traffic. +func EgressTraffic(traffic int) { + trafficChan <- &trafficData{ + traffic: traffic, + ingress: false, + } +} diff --git a/stats/server.go b/stats/server.go new file mode 100644 index 000000000..a5369e4d0 --- /dev/null +++ b/stats/server.go @@ -0,0 +1,57 @@ +package stats + +import ( + "encoding/json" + "net/http" + "sync" + "time" + + "go.uber.org/zap" + + "github.com/9seconds/mtg/config" +) + +var instance *stats + +// Start starts new statisitcs server. +func Start(conf *config.Config) { + log := zap.S().Named("stats") + + instance = &stats{ + URLs: conf.GetURLs(), + Uptime: uptime(time.Now()), + mutex: &sync.RWMutex{}, + } + + go crashManager() + go connectionManager() + go trafficManager() + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + instance.mutex.Lock() + first, err := json.Marshal(instance) + instance.mutex.Unlock() + + if err != nil { + log.Errorw("Cannot encode json", "error", err) + http.Error(w, "Internal server error", 500) + return + } + + interm := map[string]interface{}{} + json.Unmarshal(first, &interm) // nolint: errcheck + + encoder := json.NewEncoder(w) + encoder.SetEscapeHTML(false) + encoder.SetIndent("", " ") + if err = encoder.Encode(interm); err != nil { + log.Errorw("Cannot encode json", "error", err) + } + }) + + if err := http.ListenAndServe(conf.StatAddr(), nil); err != nil { + log.Fatalw("Stats server has been stopped", "error", err) + } +} diff --git a/stats/stats.go b/stats/stats.go new file mode 100644 index 000000000..5c6099ccd --- /dev/null +++ b/stats/stats.go @@ -0,0 +1,100 @@ +package stats + +import ( + "encoding/json" + "fmt" + "strconv" + "sync" + "time" + + humanize "github.com/dustin/go-humanize" + + "github.com/9seconds/mtg/config" +) + +type uptime time.Time + +func (u uptime) MarshalJSON() ([]byte, error) { + duration := time.Since(time.Time(u)) + value := map[string]string{ + "seconds": strconv.Itoa(int(duration.Seconds())), + "human": humanize.Time(time.Time(u)), + } + + return json.Marshal(value) +} + +type trafficValue uint64 + +func (t trafficValue) MarshalJSON() ([]byte, error) { + tv := uint64(t) + value := map[string]interface{}{ + "bytes": tv, + "human": humanize.Bytes(tv), + } + + return json.Marshal(value) +} + +type trafficSpeedValue uint64 + +func (t trafficSpeedValue) MarshalJSON() ([]byte, error) { + speed := uint64(t) + value := map[string]interface{}{ + "bytes/s": speed, + "human": fmt.Sprintf("%s/S", humanize.Bytes(speed)), + } + + return json.Marshal(value) +} + +type connections struct { + All connectionType `json:"all"` + Abridged connectionType `json:"abridged"` + Intermediate connectionType `json:"intermediate"` +} + +func (c connections) MarshalJSON() ([]byte, error) { + c.All.IPv4 = c.Abridged.IPv4 + c.Intermediate.IPv4 + c.All.IPv6 = c.Abridged.IPv6 + c.Intermediate.IPv6 + + value := struct { + All connectionType `json:"all"` + Abridged connectionType `json:"abridged"` + Intermediate connectionType `json:"intermediate"` + }{ + All: c.All, + Abridged: c.Abridged, + Intermediate: c.Intermediate, + } + + return json.Marshal(value) +} + +type connectionType struct { + IPv6 uint32 `json:"ipv6"` + IPv4 uint32 `json:"ipv4"` +} + +type traffic struct { + Ingress trafficValue `json:"ingress"` + Egress trafficValue `json:"egress"` +} + +type speed struct { + Ingress trafficSpeedValue `json:"ingress"` + Egress trafficSpeedValue `json:"egress"` +} + +type stats struct { + URLs config.IPURLs `json:"urls"` + ActiveConnections connections `json:"active_connections"` + AllConnections connections `json:"all_connections"` + Traffic traffic `json:"traffic"` + Speed speed `json:"speed"` + Uptime uptime `json:"uptime"` + Crashes uint32 `json:"crashes"` + + speedCurrent speed + mutex *sync.RWMutex +} diff --git a/telegram/dialer.go b/telegram/dialer.go index ef51f5021..58b4ff522 100644 --- a/telegram/dialer.go +++ b/telegram/dialer.go @@ -1,21 +1,25 @@ package telegram import ( - "io" "net" "time" "github.com/juju/errors" "github.com/9seconds/mtg/config" + "github.com/9seconds/mtg/wrappers" ) const ( telegramDialTimeout = 10 * time.Second + readBufferSize = 64 * 1024 + writeBufferSize = 64 * 1024 ) type tgDialer struct { net.Dialer + + conf *config.Config } func (t *tgDialer) dial(addr string) (net.Conn, error) { @@ -23,18 +27,27 @@ func (t *tgDialer) dial(addr string) (net.Conn, error) { if err != nil { return nil, errors.Annotate(err, "Cannot connect to Telegram") } - if err = config.SetSocketOptions(conn); err != nil { - return nil, errors.Annotate(err, "Cannot set socket options") + + tcpSocket := conn.(*net.TCPConn) + if err = tcpSocket.SetNoDelay(true); err != nil { + return nil, errors.Annotate(err, "Cannot set NO_DELAY to Telegram") + } + if err = tcpSocket.SetReadBuffer(readBufferSize); err != nil { + return nil, errors.Annotate(err, "Cannot set read buffer size on telegram socket") + } + if err = tcpSocket.SetWriteBuffer(writeBufferSize); err != nil { + return nil, errors.Annotate(err, "Cannot set write buffer size on telegram socket") } return conn, nil } -func (t *tgDialer) dialRWC(addr string) (io.ReadWriteCloser, error) { +func (t *tgDialer) dialRWC(addr, connID string) (wrappers.StreamReadWriteCloser, error) { conn, err := t.dial(addr) if err != nil { return nil, err } + tgConn := wrappers.NewConn(conn, connID, wrappers.ConnPurposeTelegram, t.conf.PublicIPv4, t.conf.PublicIPv6) - return conn, nil + return tgConn, nil } diff --git a/telegram/direct.go b/telegram/direct.go index 89a701660..0656a50b6 100644 --- a/telegram/direct.go +++ b/telegram/direct.go @@ -1,7 +1,6 @@ package telegram import ( - "io" "net" "github.com/juju/errors" @@ -33,7 +32,7 @@ type directTelegram struct { baseTelegram } -func (t *directTelegram) Dial(connOpts *mtproto.ConnectionOpts) (io.ReadWriteCloser, error) { +func (t *directTelegram) Dial(connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) { dc := connOpts.DC if dc < 0 { dc = -dc @@ -41,25 +40,27 @@ func (t *directTelegram) Dial(connOpts *mtproto.ConnectionOpts) (io.ReadWriteClo dc = 1 } - return t.baseTelegram.dial(dc - 1) + return t.baseTelegram.dial(dc-1, connID, connOpts.ConnectionProto) } -func (t *directTelegram) Init(connOpts *mtproto.ConnectionOpts, conn io.ReadWriteCloser) (io.ReadWriteCloser, error) { +func (t *directTelegram) Init(connOpts *mtproto.ConnectionOpts, conn wrappers.StreamReadWriteCloser) (wrappers.Wrap, error) { obfs2, frame := obfuscated2.MakeTelegramObfuscated2Frame(connOpts) - defer obfuscated2.ReturnFrame(frame) - if n, err := conn.Write(*frame); err != nil || n != len(*frame) { + if _, err := conn.Write(frame); err != nil { return nil, errors.Annotate(err, "Cannot write hadnshake frame") } - return wrappers.NewStreamCipherRWC(conn, obfs2.Encryptor, obfs2.Decryptor), nil + return wrappers.NewStreamCipher(conn, obfs2.Encryptor, obfs2.Decryptor), nil } // NewDirectTelegram returns Telegram instance which connects directly // to Telegram bypassing middleproxies. func NewDirectTelegram(conf *config.Config) Telegram { return &directTelegram{baseTelegram{ - dialer: tgDialer{net.Dialer{Timeout: telegramDialTimeout}}, + dialer: tgDialer{ + Dialer: net.Dialer{Timeout: telegramDialTimeout}, + conf: conf, + }, v4Addresses: directV4Addresses, v6Addresses: directV6Addresses, }} diff --git a/telegram/middle.go b/telegram/middle.go new file mode 100644 index 000000000..1724bf01b --- /dev/null +++ b/telegram/middle.go @@ -0,0 +1,136 @@ +package telegram + +import ( + "io" + "net" + "net/http" + "sync" + + "github.com/juju/errors" + + "github.com/9seconds/mtg/config" + "github.com/9seconds/mtg/mtproto" + "github.com/9seconds/mtg/mtproto/rpc" + "github.com/9seconds/mtg/wrappers" +) + +type middleTelegram struct { + middleTelegramCaller + + conf *config.Config +} + +func (t *middleTelegram) Init(connOpts *mtproto.ConnectionOpts, conn wrappers.StreamReadWriteCloser) (wrappers.Wrap, error) { + rpcNonceConn := wrappers.NewMTProtoFrame(conn, rpc.SeqNoNonce) + + rpcNonceReq, err := t.sendRPCNonceRequest(rpcNonceConn) + if err != nil { + return nil, err + } + rpcNonceResp, err := t.receiveRPCNonceResponse(rpcNonceConn, rpcNonceReq) + if err != nil { + return nil, err + } + + secureConn := wrappers.NewMiddleProxyCipher(conn, rpcNonceReq, rpcNonceResp, t.proxySecret) + frameConn := wrappers.NewMTProtoFrame(secureConn, rpc.SeqNoHandshake) + + rpcHandshakeReq, err := t.sendRPCHandshakeRequest(frameConn) + if err != nil { + return nil, err + } + _, err = t.receiveRPCHandshakeResponse(frameConn, rpcHandshakeReq) + if err != nil { + return nil, err + } + + proxyConn, err := wrappers.NewMTProtoProxy(frameConn, connOpts, t.conf.AdTag) + if err != nil { + return nil, err + } + proxyConn.Logger().Infow("Telegram connection initialized") + + return proxyConn, nil +} + +func (t *middleTelegram) sendRPCNonceRequest(conn io.Writer) (*rpc.NonceRequest, error) { + rpcNonceReq, err := rpc.NewNonceRequest(t.proxySecret) + if err != nil { + return nil, errors.Annotate(err, "Cannot create RPC nonce request") + } + if _, err = conn.Write(rpcNonceReq.Bytes()); err != nil { + return nil, errors.Annotate(err, "Cannot send RPC nonce request") + } + + return rpcNonceReq, nil +} + +func (t *middleTelegram) receiveRPCNonceResponse(conn wrappers.PacketReader, req *rpc.NonceRequest) (*rpc.NonceResponse, error) { + packet, err := conn.Read() + if err != nil { + return nil, errors.Annotate(err, "Cannot read RPC nonce response") + } + + rpcNonceResp, err := rpc.NewNonceResponse(packet) + if err != nil { + return nil, errors.Annotate(err, "Cannot initialize RPC nonce response") + } + if err = rpcNonceResp.Valid(req); err != nil { + return nil, errors.Annotate(err, "Invalid RPC nonce response") + } + + return rpcNonceResp, nil +} + +func (t *middleTelegram) sendRPCHandshakeRequest(conn io.Writer) (*rpc.HandshakeRequest, error) { + req := rpc.NewHandshakeRequest() + if _, err := conn.Write(req.Bytes()); err != nil { + return nil, errors.Annotate(err, "Cannot send RPC handshake request") + } + + return req, nil +} + +func (t *middleTelegram) receiveRPCHandshakeResponse(conn wrappers.PacketReader, req *rpc.HandshakeRequest) (*rpc.HandshakeResponse, error) { + packet, err := conn.Read() + if err != nil { + return nil, errors.Annotate(err, "Cannot read RPC handshake response") + } + + rpcHandshakeResp, err := rpc.NewHandshakeResponse(packet) + if err != nil { + return nil, errors.Annotate(err, "Cannot initialize RPC handshake response") + } + if err = rpcHandshakeResp.Valid(req); err != nil { + return nil, errors.Annotate(err, "Invalid RPC handshake response") + } + + return rpcHandshakeResp, nil +} + +// NewMiddleTelegram creates new instance of Telegram which works with +// middle proxies. +func NewMiddleTelegram(conf *config.Config) Telegram { + tg := &middleTelegram{ + middleTelegramCaller: middleTelegramCaller{ + baseTelegram: baseTelegram{ + dialer: tgDialer{ + Dialer: net.Dialer{Timeout: telegramDialTimeout}, + conf: conf, + }, + }, + httpClient: &http.Client{ + Timeout: middleTelegramHTTPClientTimeout, + }, + dialerMutex: &sync.RWMutex{}, + }, + conf: conf, + } + + if err := tg.update(); err != nil { + panic(err) + } + go tg.autoUpdate() + + return tg +} diff --git a/telegram/middle_caller.go b/telegram/middle_caller.go new file mode 100644 index 000000000..bb5961a7d --- /dev/null +++ b/telegram/middle_caller.go @@ -0,0 +1,152 @@ +package telegram + +import ( + "bufio" + "io/ioutil" + "net" + "net/http" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/juju/errors" + "go.uber.org/zap" + + "github.com/9seconds/mtg/mtproto" + "github.com/9seconds/mtg/wrappers" +) + +const ( + middleTelegramAutoUpdateInterval = 6 * time.Hour + middleTelegramHTTPClientTimeout = 30 * time.Second + + tgAddrProxySecret = "https://core.telegram.org/getProxySecret" // nolint: gas + tgAddrProxyV4 = "https://core.telegram.org/getProxyConfig" // nolint: gas + tgAddrProxyV6 = "https://core.telegram.org/getProxyConfigV6" // nolint: gas + tgUserAgent = "mtg" +) + +var middleTelegramProxyConfigSplitter = regexp.MustCompile(`\s+`) + +type middleTelegramCaller struct { + baseTelegram + + proxySecret []byte + dialerMutex *sync.RWMutex + httpClient *http.Client +} + +func (t *middleTelegramCaller) Dial(connID string, connOpts *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) { + dc := connOpts.DC + if dc == 0 { + dc = 1 + } + t.dialerMutex.RLock() + defer t.dialerMutex.RUnlock() + + return t.baseTelegram.dial(dc, connID, connOpts.ConnectionProto) +} + +func (t *middleTelegramCaller) autoUpdate() { + for range time.Tick(middleTelegramAutoUpdateInterval) { + if err := t.update(); err != nil { + zap.S().Warnw("Cannot update from Telegram", "error", err) + } + } +} + +func (t *middleTelegramCaller) update() error { + secret, err := t.getTelegramProxySecret() + if err != nil { + return errors.Annotate(err, "Cannot get proxy secret") + } + + v4Addresses, err := t.getTelegramAddresses(tgAddrProxyV4) + if err != nil { + return errors.Annotate(err, "Cannot get ipv4 addresses") + } + + v6Addresses, err := t.getTelegramAddresses(tgAddrProxyV6) + if err != nil { + return errors.Annotate(err, "Cannot get ipv6 addresses") + } + + t.dialerMutex.Lock() + t.proxySecret = secret + t.v4Addresses = v4Addresses + t.v6Addresses = v6Addresses + t.dialerMutex.Unlock() + + zap.S().Infow("Telegram middle proxy data has been updated") + + return nil +} + +func (t *middleTelegramCaller) getTelegramProxySecret() ([]byte, error) { + resp, err := t.call(tgAddrProxySecret) + if err != nil { + return nil, errors.Annotate(err, "Cannot access telegram server") + } + defer resp.Body.Close() // nolint: errcheck + + secret, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errors.Annotate(err, "Cannot read response") + } + + return secret, nil +} + +func (t *middleTelegramCaller) getTelegramAddresses(url string) (map[int16][]string, error) { + resp, err := t.call(url) + if err != nil { + return nil, errors.Annotate(err, "Cannot access telegram server") + } + defer resp.Body.Close() // nolint: errcheck + + scanner := bufio.NewScanner(resp.Body) + data := map[int16][]string{} + for scanner.Scan() { + text := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(text, "#") { + continue + } + + chunks := middleTelegramProxyConfigSplitter.Split(text, 3) + if len(chunks) != 3 || chunks[0] != "proxy_for" { + return nil, errors.Errorf("Incorrect config '%s'", text) + } + dcIdx64, err2 := strconv.ParseInt(chunks[1], 10, 16) + if err2 != nil { + return nil, errors.Errorf("Incorrect config '%s'", text) + } + dcIdx := int16(dcIdx64) + + addr := strings.TrimRight(chunks[2], ";") + if _, _, err2 = net.SplitHostPort(addr); err != nil { + return nil, errors.Annotatef(err2, "Incorrect config '%s'", text) + } + + if addresses, ok := data[dcIdx]; ok { + data[dcIdx] = append(addresses, addr) + } else { + data[dcIdx] = []string{addr} + } + } + err = scanner.Err() + if err != nil { + return nil, errors.Annotate(err, "Cannot read response from the telegram") + } + + return data, nil +} + +func (t *middleTelegramCaller) call(url string) (*http.Response, error) { + req, _ := http.NewRequest("GET", url, nil) + req.Header.Set("Accept", "text/plain") + req.Header.Set("User-Agent", tgUserAgent) + + return t.httpClient.Do(req) +} diff --git a/telegram/telegram.go b/telegram/telegram.go index 0f71efcd0..b6b34ef9b 100644 --- a/telegram/telegram.go +++ b/telegram/telegram.go @@ -1,20 +1,18 @@ package telegram import ( - "io" "math/rand" "github.com/juju/errors" "github.com/9seconds/mtg/mtproto" + "github.com/9seconds/mtg/wrappers" ) -// Telegram defines an interface to connect to Telegram. This -// encapsulates logic of working with middleproxies or direct -// connections. +// Telegram is an interface for different Telegram work modes. type Telegram interface { - Dial(*mtproto.ConnectionOpts) (io.ReadWriteCloser, error) - Init(*mtproto.ConnectionOpts, io.ReadWriteCloser) (io.ReadWriteCloser, error) + Dial(string, *mtproto.ConnectionOpts) (wrappers.StreamReadWriteCloser, error) + Init(*mtproto.ConnectionOpts, wrappers.StreamReadWriteCloser) (wrappers.Wrap, error) } type baseTelegram struct { @@ -24,17 +22,22 @@ type baseTelegram struct { v6Addresses map[int16][]string } -func (b *baseTelegram) dial(dcIdx int16) (io.ReadWriteCloser, error) { +func (b *baseTelegram) dial(dcIdx int16, connID string, proto mtproto.ConnectionProtocol) (wrappers.StreamReadWriteCloser, error) { addrs := make([]string, 2) - if addr, ok := b.v6Addresses[dcIdx]; ok && len(addr) > 0 { - addrs = append(addrs, addr[rand.Intn(len(addr))]) + + if proto&mtproto.ConnectionProtocolIPv6 != 0 { + if addr, ok := b.v6Addresses[dcIdx]; ok && len(addr) > 0 { + addrs = append(addrs, addr[rand.Intn(len(addr))]) + } } - if addr, ok := b.v4Addresses[dcIdx]; ok && len(addr) > 0 { - addrs = append(addrs, addr[rand.Intn(len(addr))]) + if proto&mtproto.ConnectionProtocolIPv4 != 0 { + if addr, ok := b.v4Addresses[dcIdx]; ok && len(addr) > 0 { + addrs = append(addrs, addr[rand.Intn(len(addr))]) + } } for _, addr := range addrs { - if conn, err := b.dialer.dialRWC(addr); err == nil { + if conn, err := b.dialer.dialRWC(addr, connID); err == nil { return conn, err } } diff --git a/utils/read_current_data.go b/utils/read_current_data.go new file mode 100644 index 000000000..284369c43 --- /dev/null +++ b/utils/read_current_data.go @@ -0,0 +1,21 @@ +package utils + +import "io" + +const readCurrentDataBufferSize = 1024 + 1 // + 1 because telegram operates with blocks mod 4 + +// ReadCurrentData reads all data from io.Reader which is ready to be read. +func ReadCurrentData(src io.Reader) (rv []byte, err error) { + buf := make([]byte, readCurrentDataBufferSize) + n := readCurrentDataBufferSize + + for n == len(buf) { + n, err = src.Read(buf) + if err != nil { + return nil, err + } + rv = append(rv, buf[:n]...) + } + + return rv, nil +} diff --git a/utils/reverse_bytes.go b/utils/reverse_bytes.go new file mode 100644 index 000000000..ab7cd5b89 --- /dev/null +++ b/utils/reverse_bytes.go @@ -0,0 +1,15 @@ +package utils + +// ReverseBytes is a common slice reverser. +func ReverseBytes(data []byte) []byte { + dataLen := len(data) + rv := make([]byte, dataLen) + + rv[dataLen/2] = data[dataLen/2] + for i := dataLen/2 - 1; i >= 0; i-- { + opp := dataLen - i - 1 + rv[i], rv[opp] = data[opp], data[i] + } + + return rv +} diff --git a/utils/uint24.go b/utils/uint24.go new file mode 100644 index 000000000..66c5f1bb8 --- /dev/null +++ b/utils/uint24.go @@ -0,0 +1,15 @@ +package utils + +// Uint24 is a replacement for the absent Go uint24 data type. +// This data type is little endian. +type Uint24 [3]byte + +// ToUint24 converts number to Uint24. +func ToUint24(number uint32) Uint24 { + return Uint24{byte(number), byte(number >> 8), byte(number >> 16)} +} + +// FromUint24 converts Uint24 to number. +func FromUint24(number Uint24) uint32 { + return uint32(number[0]) + (uint32(number[1]) << 8) + (uint32(number[2]) << 16) +} diff --git a/wrappers/blockcipher.go b/wrappers/blockcipher.go new file mode 100644 index 000000000..39fe7e899 --- /dev/null +++ b/wrappers/blockcipher.go @@ -0,0 +1,99 @@ +package wrappers + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "net" + + "go.uber.org/zap" + + "github.com/9seconds/mtg/utils" + "github.com/juju/errors" +) + +// BlockCipher is a stream writer which encrypts/decrypts blocks of data +// with AES CBC. This also is buffered reader. It means, that block +// reading is transparent for it, you can assume you are working with +// good old io.Reader. +type BlockCipher struct { + buf *bytes.Buffer + + logger *zap.SugaredLogger + conn StreamReadWriteCloser + encryptor cipher.BlockMode + decryptor cipher.BlockMode +} + +func (b *BlockCipher) Read(p []byte) (int, error) { + if b.buf.Len() > 0 { + return b.flush(p) + } + + buf := []byte{} + for len(buf) == 0 || len(buf)%aes.BlockSize != 0 { + rv, err := utils.ReadCurrentData(b.conn) + if err != nil { + return 0, errors.Annotate(err, "Cannot read from socket") + } + buf = append(buf, rv...) + } + + b.decryptor.CryptBlocks(buf, buf) + b.buf.Write(buf) + + return b.flush(p) +} + +func (b *BlockCipher) flush(p []byte) (int, error) { + if b.buf.Len() <= len(p) { + sizeToReturn := b.buf.Len() + copy(p, b.buf.Bytes()) + b.buf.Reset() + return sizeToReturn, nil + } + + return b.buf.Read(p) +} + +func (b *BlockCipher) Write(p []byte) (int, error) { + if len(p)%aes.BlockSize > 0 { + return 0, errors.Errorf("Incorrect block size %d", len(p)) + } + + encrypted := make([]byte, len(p)) + b.encryptor.CryptBlocks(encrypted, p) + + return b.conn.Write(encrypted) +} + +// Logger returns an instance of the logger for this wrapper. +func (b *BlockCipher) Logger() *zap.SugaredLogger { + return b.logger +} + +// LocalAddr returns local address of the underlying net.Conn. +func (b *BlockCipher) LocalAddr() *net.TCPAddr { + return b.conn.LocalAddr() +} + +// RemoteAddr returns remote address of the underlying net.Conn. +func (b *BlockCipher) RemoteAddr() *net.TCPAddr { + return b.conn.RemoteAddr() +} + +// Close closes underlying net.Conn. +func (b *BlockCipher) Close() error { + return b.conn.Close() +} + +// NewBlockCipher creates new instance of BlockCipher based on given data. +func NewBlockCipher(conn StreamReadWriteCloser, encryptor, decryptor cipher.BlockMode) StreamReadWriteCloser { + return &BlockCipher{ + buf: &bytes.Buffer{}, + conn: conn, + logger: conn.Logger().Named("block-cipher"), + encryptor: encryptor, + decryptor: decryptor, + } +} diff --git a/wrappers/buffer_pool.go b/wrappers/buffer_pool.go deleted file mode 100644 index ead700dfc..000000000 --- a/wrappers/buffer_pool.go +++ /dev/null @@ -1,27 +0,0 @@ -package wrappers - -import ( - "bytes" - "sync" -) - -var bufPool sync.Pool - -func getBuffer() *bytes.Buffer { - buf := bufPool.Get().(*bytes.Buffer) - buf.Reset() - - return buf -} - -func putBuffer(buf *bytes.Buffer) { - bufPool.Put(buf) -} - -func init() { - bufPool = sync.Pool{ - New: func() interface{} { - return &bytes.Buffer{} - }, - } -} diff --git a/wrappers/conn.go b/wrappers/conn.go new file mode 100644 index 000000000..408483a59 --- /dev/null +++ b/wrappers/conn.go @@ -0,0 +1,121 @@ +package wrappers + +import ( + "net" + "time" + + "go.uber.org/zap" + + "github.com/9seconds/mtg/stats" +) + +// ConnPurpose is intented to be identifier of connection purpose. We +// sometimes want to treat client/telegram connection differently (for +// logging for example). +type ConnPurpose uint8 + +func (c ConnPurpose) String() string { + switch c { + case ConnPurposeClient: + return "client" + case ConnPurposeTelegram: + return "telegram" + } + + return "" +} + +// ConnPurpose* define different connection types. +const ( + ConnPurposeClient = iota + ConnPurposeTelegram +) + +const ( + connTimeoutRead = 2 * time.Minute + connTimeoutWrite = 2 * time.Minute +) + +// Conn is a basic wrapper for net.Conn providing the most low-level +// logic and management as possible. +type Conn struct { + connID string + conn net.Conn + logger *zap.SugaredLogger + + publicIPv4 net.IP + publicIPv6 net.IP +} + +func (c *Conn) Write(p []byte) (int, error) { + c.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite)) // nolint: errcheck + n, err := c.conn.Write(p) + + c.logger.Debugw("Write to stream", "bytes", n, "error", err) + stats.EgressTraffic(n) + + return n, err +} + +func (c *Conn) Read(p []byte) (int, error) { + c.conn.SetReadDeadline(time.Now().Add(connTimeoutRead)) // nolint: errcheck + n, err := c.conn.Read(p) + + c.logger.Debugw("Read from stream", "bytes", n, "error", err) + stats.IngressTraffic(n) + + return n, err +} + +// Close closes underlying net.Conn instance. +func (c *Conn) Close() error { + defer c.logger.Debugw("Close connection") + return c.conn.Close() +} + +// Logger returns an instance of the logger for this wrapper. +func (c *Conn) Logger() *zap.SugaredLogger { + return c.logger +} + +// LocalAddr returns local address of the underlying net.Conn. +func (c *Conn) LocalAddr() *net.TCPAddr { + addr := c.conn.LocalAddr().(*net.TCPAddr) + newAddr := *addr + + if c.RemoteAddr().IP.To4() != nil { + if c.publicIPv4 != nil { + newAddr.IP = c.publicIPv4 + } + } else if c.publicIPv6 != nil { + newAddr.IP = c.publicIPv6 + } + + return &newAddr +} + +// RemoteAddr returns remote address of the underlying net.Conn. +func (c *Conn) RemoteAddr() *net.TCPAddr { + return c.conn.RemoteAddr().(*net.TCPAddr) +} + +// NewConn initializes Conn wrapper for net.Conn. +func NewConn(conn net.Conn, connID string, purpose ConnPurpose, publicIPv4, publicIPv6 net.IP) StreamReadWriteCloser { + logger := zap.S().With( + "connection_id", connID, + "local_address", conn.LocalAddr(), + "remote_address", conn.RemoteAddr(), + "purpose", purpose, + ).Named("conn") + + wrapper := Conn{ + logger: logger, + connID: connID, + conn: conn, + publicIPv4: publicIPv4, + publicIPv6: publicIPv6, + } + wrapper.logger = logger.With("faked_local_addr", wrapper.LocalAddr()) + + return &wrapper +} diff --git a/wrappers/ctxrwc.go b/wrappers/ctxrwc.go deleted file mode 100644 index 26f47be60..000000000 --- a/wrappers/ctxrwc.go +++ /dev/null @@ -1,59 +0,0 @@ -package wrappers - -import ( - "context" - "io" - - "github.com/juju/errors" -) - -// CtxReadWriteCloser wraps underlying connection and does management of the -// context and its cancel function. -type CtxReadWriteCloser struct { - ctx context.Context - conn io.ReadWriteCloser - cancel context.CancelFunc -} - -// Read reads from connection -func (c *CtxReadWriteCloser) Read(p []byte) (int, error) { - select { - case <-c.ctx.Done(): - return 0, errors.Annotate(c.ctx.Err(), "Read is failed because of closed context") - default: - n, err := c.conn.Read(p) - if err != nil { - c.cancel() - } - return n, err - } -} - -// Write writes into connection. -func (c *CtxReadWriteCloser) Write(p []byte) (int, error) { - select { - case <-c.ctx.Done(): - return 0, errors.Annotate(c.ctx.Err(), "Write is failed because of closed context") - default: - n, err := c.conn.Write(p) - if err != nil { - c.cancel() - } - return n, err - } -} - -// Close closes underlying connection. -func (c *CtxReadWriteCloser) Close() error { - return c.conn.Close() -} - -// NewCtxRWC returns ReadWriteCloser which respects given context, -// cancellation etc. -func NewCtxRWC(ctx context.Context, cancel context.CancelFunc, conn io.ReadWriteCloser) io.ReadWriteCloser { - return &CtxReadWriteCloser{ - conn: conn, - ctx: ctx, - cancel: cancel, - } -} diff --git a/wrappers/logrwc.go b/wrappers/logrwc.go deleted file mode 100644 index 355e2398f..000000000 --- a/wrappers/logrwc.go +++ /dev/null @@ -1,47 +0,0 @@ -package wrappers - -import ( - "io" - - "go.uber.org/zap" -) - -// LogReadWriteCloser adds additional logging for reading/writing. All -// logging is performed for debug mode only. -type LogReadWriteCloser struct { - conn io.ReadWriteCloser - logger *zap.SugaredLogger - sockid string - name string -} - -// Read reads from connection -func (l *LogReadWriteCloser) Read(p []byte) (n int, err error) { - n, err = l.conn.Read(p) - l.logger.Debugw("Finish reading", "name", l.name, "socketid", l.sockid, "nbytes", n, "error", err) - return -} - -// Write writes into connection. -func (l *LogReadWriteCloser) Write(p []byte) (n int, err error) { - n, err = l.conn.Write(p) - l.logger.Debugw("Finish writing", "name", l.name, "socketid", l.sockid, "nbytes", n, "error", err) - return -} - -// Close closes underlying connection. -func (l *LogReadWriteCloser) Close() error { - err := l.conn.Close() - l.logger.Debugw("Finish closing socket", "name", l.name, "socketid", l.sockid, "error", err) - return err -} - -// NewLogRWC wraps ReadWriteCloser with logger calls. -func NewLogRWC(conn io.ReadWriteCloser, logger *zap.SugaredLogger, sockid string, name string) io.ReadWriteCloser { - return &LogReadWriteCloser{ - conn: conn, - logger: logger, - sockid: sockid, - name: name, - } -} diff --git a/wrappers/mtproto_abridged.go b/wrappers/mtproto_abridged.go new file mode 100644 index 000000000..e4775d01f --- /dev/null +++ b/wrappers/mtproto_abridged.go @@ -0,0 +1,159 @@ +package wrappers + +import ( + "bytes" + "io" + "net" + + "github.com/juju/errors" + "go.uber.org/zap" + + "github.com/9seconds/mtg/mtproto" + "github.com/9seconds/mtg/utils" +) + +const ( + mtprotoAbridgedSmallPacketLength = 0x7f + mtprotoAbridgedQuickAckLength = 0x80 + mtprotoAbridgedLargePacketLength = 16777216 // 256 ^ 3 +) + +// MTProtoAbridged presents abridged connection between client and +// middle proxy. +type MTProtoAbridged struct { + conn StreamReadWriteCloser + opts *mtproto.ConnectionOpts + logger *zap.SugaredLogger + + readCounter uint32 + writeCounter uint32 +} + +func (m *MTProtoAbridged) Read() ([]byte, error) { + defer func() { + m.readCounter++ + }() + + m.logger.Debugw("Read packet", + "simple_ack", m.opts.ReadHacks.SimpleAck, + "quick_ack", m.opts.ReadHacks.QuickAck, + "counter", m.readCounter, + ) + + buf := &bytes.Buffer{} + buf.Grow(3) + + if _, err := io.CopyN(buf, m.conn, 1); err != nil { + return nil, errors.Annotate(err, "Cannot read message length") + } + msgLength := uint32(buf.Bytes()[0]) + buf.Reset() + + m.logger.Debugw("Packet first byte", + "byte", msgLength, + "counter", m.readCounter, + "simple_ack", m.opts.ReadHacks.SimpleAck, + "quick_ack", m.opts.ReadHacks.QuickAck, + ) + + if msgLength >= mtprotoAbridgedQuickAckLength { + m.opts.ReadHacks.QuickAck = true + msgLength -= mtprotoAbridgedQuickAckLength + } + + if msgLength == mtprotoAbridgedSmallPacketLength { + if _, err := io.CopyN(buf, m.conn, 3); err != nil { + return nil, errors.Annotate(err, "Cannot read the correct message length") + } + number := utils.Uint24{} + copy(number[:], buf.Bytes()) + msgLength = utils.FromUint24(number) + } + msgLength *= 4 + + m.logger.Debugw("Packet length", + "length", msgLength, + "simple_ack", m.opts.ReadHacks.SimpleAck, + "quick_ack", m.opts.ReadHacks.QuickAck, + "counter", m.readCounter, + ) + + buf.Reset() + buf.Grow(int(msgLength)) + if _, err := io.CopyN(buf, m.conn, int64(msgLength)); err != nil { + return nil, errors.Annotate(err, "Cannot read message") + } + + return buf.Bytes(), nil +} + +func (m *MTProtoAbridged) Write(p []byte) (int, error) { + defer func() { + m.writeCounter++ + }() + + m.logger.Debugw("Write packet", + "length", len(p), + "simple_ack", m.opts.WriteHacks.SimpleAck, + "quick_ack", m.opts.WriteHacks.QuickAck, + "counter", m.writeCounter, + ) + + if len(p)%4 != 0 { + return 0, errors.Errorf("Incorrect packet length %d", len(p)) + } + + if m.opts.WriteHacks.SimpleAck { + return m.conn.Write(utils.ReverseBytes(p)) + } + + packetLength := len(p) / 4 + switch { + case packetLength < mtprotoAbridgedSmallPacketLength: + newData := append([]byte{byte(packetLength)}, p...) + return m.conn.Write(newData) + + case packetLength < mtprotoAbridgedLargePacketLength: + length24 := utils.ToUint24(uint32(packetLength)) + + buf := &bytes.Buffer{} + buf.Grow(1 + 3 + len(p)) + + buf.WriteByte(byte(mtprotoAbridgedSmallPacketLength)) + buf.Write(length24[:]) + buf.Write(p) + + return m.conn.Write(buf.Bytes()) + } + + return 0, errors.Errorf("Packet is too big %d", len(p)) +} + +// Logger returns an instance of the logger for this wrapper. +func (m *MTProtoAbridged) Logger() *zap.SugaredLogger { + return m.logger +} + +// LocalAddr returns local address of the underlying net.Conn. +func (m *MTProtoAbridged) LocalAddr() *net.TCPAddr { + return m.conn.LocalAddr() +} + +// RemoteAddr returns remote address of the underlying net.Conn. +func (m *MTProtoAbridged) RemoteAddr() *net.TCPAddr { + return m.conn.RemoteAddr() +} + +// Close closes underlying net.Conn instance. +func (m *MTProtoAbridged) Close() error { + return m.conn.Close() +} + +// NewMTProtoAbridged creates new wrapper for abridged client connection. +func NewMTProtoAbridged(conn StreamReadWriteCloser, opts *mtproto.ConnectionOpts) PacketReadWriteCloser { + return &MTProtoAbridged{ + conn: conn, + opts: opts, + logger: conn.Logger().Named("mtproto-abridged"), + } +} diff --git a/wrappers/mtproto_cipher.go b/wrappers/mtproto_cipher.go new file mode 100644 index 000000000..bd26eb328 --- /dev/null +++ b/wrappers/mtproto_cipher.go @@ -0,0 +1,96 @@ +package wrappers + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/md5" // nolint: gas + "crypto/sha1" + "encoding/binary" + "net" + + "github.com/9seconds/mtg/mtproto/rpc" + "github.com/9seconds/mtg/utils" +) + +type cipherPurpose uint8 + +const ( + cipherPurposeClient cipherPurpose = iota + cipherPurposeServer +) + +var emptyIP = [4]byte{0x00, 0x00, 0x00, 0x00} + +// NewMiddleProxyCipher creates new block cipher to proxy<->telegram +// connection. +func NewMiddleProxyCipher(conn StreamReadWriteCloser, req *rpc.NonceRequest, resp *rpc.NonceResponse, secret []byte) StreamReadWriteCloser { + localAddr := conn.LocalAddr() + remoteAddr := conn.RemoteAddr() + + encKey, encIV := deriveKeys(cipherPurposeClient, req, resp, localAddr, remoteAddr, secret) + decKey, decIV := deriveKeys(cipherPurposeServer, req, resp, localAddr, remoteAddr, secret) + + enc, _ := makeEncrypterDecrypter(encKey, encIV) + _, dec := makeEncrypterDecrypter(decKey, decIV) + + return NewBlockCipher(conn, enc, dec) +} + +func deriveKeys(purpose cipherPurpose, req *rpc.NonceRequest, resp *rpc.NonceResponse, client *net.TCPAddr, remote *net.TCPAddr, secret []byte) ([]byte, []byte) { + message := bytes.Buffer{} + message.Write(resp.Nonce[:]) + message.Write(req.Nonce[:]) + message.Write(req.CryptoTS[:]) + + clientIPv4 := emptyIP[:] + serverIPv4 := emptyIP[:] + if client.IP.To4() != nil { + clientIPv4 = utils.ReverseBytes(client.IP.To4()) + serverIPv4 = utils.ReverseBytes(remote.IP.To4()) + } + message.Write(serverIPv4) + + var port [2]byte + binary.LittleEndian.PutUint16(port[:], uint16(client.Port)) + message.Write(port[:]) + + switch purpose { + case cipherPurposeClient: + message.WriteString("CLIENT") + case cipherPurposeServer: + message.WriteString("SERVER") + default: + panic("Unexpected cipher purpose") + } + + message.Write(clientIPv4) + binary.LittleEndian.PutUint16(port[:], uint16(remote.Port)) + message.Write(port[:]) + message.Write(secret) + message.Write(resp.Nonce[:]) + + if client.IP.To4() == nil { + message.Write(client.IP.To16()) + message.Write(remote.IP.To16()) + } + message.Write(req.Nonce[:]) + + data := message.Bytes() + md5sum := md5.Sum(data[1:]) // nolint: gas + sha1sum := sha1.Sum(data) + + key := append(md5sum[:12], sha1sum[:]...) + iv := md5.Sum(data[2:]) // nolint: gas + + return key, iv[:] +} + +func makeEncrypterDecrypter(key, iv []byte) (cipher.BlockMode, cipher.BlockMode) { + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + return cipher.NewCBCEncrypter(block, iv), cipher.NewCBCDecrypter(block, iv) +} diff --git a/wrappers/mtproto_frame.go b/wrappers/mtproto_frame.go new file mode 100644 index 000000000..58a523ceb --- /dev/null +++ b/wrappers/mtproto_frame.go @@ -0,0 +1,160 @@ +package wrappers + +import ( + "bytes" + "crypto/aes" + "encoding/binary" + "hash/crc32" + "io" + "io/ioutil" + "net" + + "github.com/juju/errors" + "go.uber.org/zap" +) + +const ( + mtprotoFrameMinMessageLength = 12 + mtprotoFrameMaxMessageLength = 16777216 +) + +var mtprotoFramePadding = []byte{0x04, 0x00, 0x00, 0x00} + +// MTProtoFrame is a wrapper which converts written data to the MTProtoFrame. +// The format of the frame: +// +// [ MSGLEN(4) | SEQNO(4) | MSG(...) | CRC32(4) | PADDING(4*x) ] +// +// MSGLEN is the length of the message + len of seqno and msglen. +// SEQNO is the number of frame in the receive/send sequence. If client +// sends a message with SeqNo 18, it has to receive message with SeqNo 18. +// MSG is the data which has to be written +// CRC32 is the CRC32 checksum of MSGLEN + SEQNO + MSG +// PADDING is custom padding schema to complete frame length to such that +// len(frame) % 16 == 0 +type MTProtoFrame struct { + conn StreamReadWriteCloser + logger *zap.SugaredLogger + + readSeqNo int32 + writeSeqNo int32 +} + +func (m *MTProtoFrame) Read() ([]byte, error) { // nolint: gocyclo + buf := &bytes.Buffer{} + sum := crc32.NewIEEE() + writer := io.MultiWriter(buf, sum) + + for { + buf.Reset() + sum.Reset() + if _, err := io.CopyN(writer, m.conn, 4); err != nil { + return nil, errors.Annotate(err, "Cannot read frame padding") + } + if !bytes.Equal(buf.Bytes(), mtprotoFramePadding) { + break + } + } + + messageLength := binary.LittleEndian.Uint32(buf.Bytes()) + m.logger.Debugw("Read MTProto frame", + "messageLength", messageLength, + "sequence_number", m.readSeqNo, + ) + if messageLength%4 != 0 || messageLength < mtprotoFrameMinMessageLength || messageLength > mtprotoFrameMaxMessageLength { + return nil, errors.Errorf("Incorrect frame message length %d", messageLength) + } + + buf.Reset() + buf.Grow(int(messageLength) - 4 - 4) + if _, err := io.CopyN(writer, m.conn, int64(messageLength)-4-4); err != nil { + return nil, errors.Annotate(err, "Cannot read the message frame") + } + + var seqNo int32 + binary.Read(buf, binary.LittleEndian, &seqNo) // nolint: errcheck + if seqNo != m.readSeqNo { + return nil, errors.Errorf("Unexpected sequence number %d (wait for %d)", seqNo, m.readSeqNo) + } + + data, _ := ioutil.ReadAll(buf) + buf.Reset() + // write to buf, not to writer. This is because we are going to fetch + // crc32 checksum. + if _, err := io.CopyN(buf, m.conn, 4); err != nil { + return nil, errors.Annotate(err, "Cannot read checksum") + } + + checksum := binary.LittleEndian.Uint32(buf.Bytes()) + if checksum != sum.Sum32() { + return nil, errors.Errorf("CRC32 checksum mismatch. Wait for %d, got %d", sum.Sum32(), checksum) + } + + m.logger.Debugw("Read MTProto frame", + "messageLength", messageLength, + "sequence_number", m.readSeqNo, + "dataLength", len(data), + "checksum", checksum, + ) + m.readSeqNo++ + + return data, nil +} + +func (m *MTProtoFrame) Write(p []byte) (int, error) { + messageLength := 4 + 4 + len(p) + 4 + paddingLength := (aes.BlockSize - messageLength%aes.BlockSize) % aes.BlockSize + + buf := &bytes.Buffer{} + buf.Grow(messageLength + paddingLength) + + binary.Write(buf, binary.LittleEndian, uint32(messageLength)) // nolint: errcheck + binary.Write(buf, binary.LittleEndian, m.writeSeqNo) // nolint: errcheck + buf.Write(p) + + checksum := crc32.ChecksumIEEE(buf.Bytes()) + binary.Write(buf, binary.LittleEndian, checksum) // nolint: errcheck + buf.Write(bytes.Repeat(mtprotoFramePadding, paddingLength/4)) + + m.logger.Debugw("Write MTProto frame", + "length", len(p), + "sequence_number", m.writeSeqNo, + "crc32", checksum, + "frame_length", buf.Len(), + ) + m.writeSeqNo++ + + _, err := m.conn.Write(buf.Bytes()) + + return len(p), err +} + +// Logger returns an instance of the logger for this wrapper. +func (m *MTProtoFrame) Logger() *zap.SugaredLogger { + return m.logger +} + +// LocalAddr returns local address of the underlying net.Conn. +func (m *MTProtoFrame) LocalAddr() *net.TCPAddr { + return m.conn.LocalAddr() +} + +// RemoteAddr returns remote address of the underlying net.Conn. +func (m *MTProtoFrame) RemoteAddr() *net.TCPAddr { + return m.conn.RemoteAddr() +} + +// Close closes underlying net.Conn instance. +func (m *MTProtoFrame) Close() error { + return m.conn.Close() +} + +// NewMTProtoFrame creates new PacketWrapper for underlying connection. +func NewMTProtoFrame(conn StreamReadWriteCloser, seqNo int32) PacketReadWriteCloser { + return &MTProtoFrame{ + conn: conn, + logger: conn.Logger().Named("mtproto-frame"), + readSeqNo: seqNo, + writeSeqNo: seqNo, + } +} diff --git a/wrappers/mtproto_intermediate.go b/wrappers/mtproto_intermediate.go new file mode 100644 index 000000000..ab0761c5b --- /dev/null +++ b/wrappers/mtproto_intermediate.go @@ -0,0 +1,121 @@ +package wrappers + +import ( + "bytes" + "encoding/binary" + "io" + "net" + + "github.com/juju/errors" + "go.uber.org/zap" + + "github.com/9seconds/mtg/mtproto" +) + +const mtprotoIntermediateQuickAckLength = 0x80000000 + +// MTProtoIntermediate presents intermediate connection between client +// and Telegram. +type MTProtoIntermediate struct { + conn StreamReadWriteCloser + opts *mtproto.ConnectionOpts + logger *zap.SugaredLogger + + readCounter uint32 + writeCounter uint32 +} + +func (m *MTProtoIntermediate) Read() ([]byte, error) { + defer func() { + m.readCounter++ + }() + + m.logger.Debugw("Read packet", + "simple_ack", m.opts.ReadHacks.SimpleAck, + "quick_ack", m.opts.ReadHacks.QuickAck, + "counter", m.readCounter, + ) + + buf := &bytes.Buffer{} + buf.Grow(4) + + if _, err := io.CopyN(buf, m.conn, 4); err != nil { + return nil, errors.Annotate(err, "Cannot read message length") + } + length := binary.LittleEndian.Uint32(buf.Bytes()) + + m.logger.Debugw("Packet message length", + "simple_ack", m.opts.ReadHacks.SimpleAck, + "quick_ack", m.opts.ReadHacks.QuickAck, + "counter", m.readCounter, + "length", length, + ) + + if length > mtprotoIntermediateQuickAckLength { + m.opts.ReadHacks.QuickAck = true + length -= mtprotoIntermediateQuickAckLength + } + + buf.Reset() + buf.Grow(int(length)) + if _, err := io.CopyN(buf, m.conn, int64(length)); err != nil { + return nil, errors.Annotate(err, "Cannot read the message") + } + + if length%4 != 0 { + length -= length % 4 + } + + return buf.Bytes()[:length], nil +} + +func (m *MTProtoIntermediate) Write(p []byte) (int, error) { + defer func() { + m.writeCounter++ + }() + + m.logger.Debugw("Write packet", + "simple_ack", m.opts.WriteHacks.SimpleAck, + "quick_ack", m.opts.WriteHacks.QuickAck, + "counter", m.writeCounter, + ) + + if m.opts.ReadHacks.SimpleAck { + return m.conn.Write(p) + } + + var length [4]byte + binary.LittleEndian.PutUint32(length[:], uint32(len(p))) + + return m.conn.Write(append(length[:], p...)) +} + +// Logger returns an instance of the logger for this wrapper. +func (m *MTProtoIntermediate) Logger() *zap.SugaredLogger { + return m.logger +} + +// LocalAddr returns local address of the underlying net.Conn. +func (m *MTProtoIntermediate) LocalAddr() *net.TCPAddr { + return m.conn.LocalAddr() +} + +// RemoteAddr returns remote address of the underlying net.Conn. +func (m *MTProtoIntermediate) RemoteAddr() *net.TCPAddr { + return m.conn.RemoteAddr() +} + +// Close closes underlying net.Conn instance. +func (m *MTProtoIntermediate) Close() error { + return m.conn.Close() +} + +// NewMTProtoIntermediate creates new PacketWrapper for intermediate +// client connection. +func NewMTProtoIntermediate(conn StreamReadWriteCloser, opts *mtproto.ConnectionOpts) PacketReadWriteCloser { + return &MTProtoIntermediate{ + conn: conn, + logger: conn.Logger().Named("mtproto-intermediate"), + opts: opts, + } +} diff --git a/wrappers/mtproto_proxy.go b/wrappers/mtproto_proxy.go new file mode 100644 index 000000000..1babb95b4 --- /dev/null +++ b/wrappers/mtproto_proxy.go @@ -0,0 +1,164 @@ +package wrappers + +import ( + "bytes" + "fmt" + "net" + + "github.com/juju/errors" + "go.uber.org/zap" + + "github.com/9seconds/mtg/mtproto" + "github.com/9seconds/mtg/mtproto/rpc" +) + +// MTProtoProxy is a wrapper which creates/reads RPC responses from Telegram. +type MTProtoProxy struct { + conn PacketReadWriteCloser + req *rpc.ProxyRequest + logger *zap.SugaredLogger + + readCounter uint32 + writeCounter uint32 +} + +func (m *MTProtoProxy) Read() ([]byte, error) { + defer func() { + m.readCounter++ + }() + + m.logger.Debugw("Read packet", + "counter", m.readCounter, + "simple_ack", m.req.Options.WriteHacks.SimpleAck, + "quick_ack", m.req.Options.WriteHacks.QuickAck, + ) + + packet, err := m.conn.Read() + if err != nil { + return nil, errors.Annotate(err, "Cannot read packet") + } + + m.logger.Debugw("Read packet length", + "counter", m.readCounter, + "simple_ack", m.req.Options.WriteHacks.SimpleAck, + "quick_ack", m.req.Options.WriteHacks.QuickAck, + "length", len(packet), + ) + + if len(packet) < 4 { + return nil, errors.Annotate(err, "Incorrect packet length") + } + + tag, packet := packet[:4], packet[4:] + switch { + case bytes.Equal(tag, rpc.TagProxyAns): + return m.readProxyAns(packet) + case bytes.Equal(tag, rpc.TagSimpleAck): + return m.readSimpleAck(packet) + case bytes.Equal(tag, rpc.TagCloseExt): + return m.readCloseExt(packet) + } + + return nil, errors.Errorf("Unknown RPC answer %v", tag) +} + +func (m *MTProtoProxy) readProxyAns(data []byte) ([]byte, error) { + if len(data) < 12 { + return nil, errors.Errorf("Incorrect data of proxy answer: %d", len(data)) + } + data = data[12:] + + m.logger.Debugw("Read RPC_PROXY_ANS", + "counter", m.readCounter, + "length", len(data), + ) + + return data, nil +} + +func (m *MTProtoProxy) readSimpleAck(data []byte) ([]byte, error) { + if len(data) != 12 { + return nil, errors.Errorf("Incorrect data of simple ack: %d", len(data)) + } + data = data[8:12] + m.req.Options.WriteHacks.SimpleAck = true + + m.logger.Debugw("Read RPC_SIMPLE_ACK", + "counter", m.readCounter, + "length", len(data), + ) + + return data, nil +} + +func (m *MTProtoProxy) readCloseExt(data []byte) ([]byte, error) { + m.logger.Debugw("Read RPC_CLOSE_EXT", "counter", m.readCounter) + + return nil, errors.New("Connection has been closed remotely by RPC call") +} + +func (m *MTProtoProxy) Write(p []byte) (int, error) { + defer func() { + m.writeCounter++ + }() + + m.logger.Debugw("Write packet", + "length", len(p), + "counter", m.writeCounter, + "simple_ack", m.req.Options.ReadHacks.SimpleAck, + "quick_ack", m.req.Options.ReadHacks.QuickAck, + ) + + header, flags := m.req.MakeHeader(p) + if ce := m.logger.Desugar().Check(zap.DebugLevel, "RPC_PROXY_REQ header"); ce != nil { + ce.Write( + zap.Int("length", len(p)), + zap.Uint32("counter", m.writeCounter), + zap.Bool("simple_ack", m.req.Options.ReadHacks.QuickAck), + zap.Bool("quick_ack", m.req.Options.ReadHacks.SimpleAck), + zap.String("header", fmt.Sprintf("%v", header.Bytes())), + zap.Stringer("flags", flags), + ) + } + header.Write(p) + + if _, err := m.conn.Write(header.Bytes()); err != nil { + return 0, err + } + + return len(p), nil +} + +// Logger returns an instance of the logger for this wrapper. +func (m *MTProtoProxy) Logger() *zap.SugaredLogger { + return m.logger +} + +// LocalAddr returns local address of the underlying net.Conn. +func (m *MTProtoProxy) LocalAddr() *net.TCPAddr { + return m.conn.LocalAddr() +} + +// RemoteAddr returns remote address of the underlying net.Conn. +func (m *MTProtoProxy) RemoteAddr() *net.TCPAddr { + return m.conn.RemoteAddr() +} + +// Close closes underlying net.Conn instance. +func (m *MTProtoProxy) Close() error { + return m.conn.Close() +} + +// NewMTProtoProxy creates new RPC wrapper. +func NewMTProtoProxy(conn PacketReadWriteCloser, connOpts *mtproto.ConnectionOpts, adTag []byte) (PacketReadWriteCloser, error) { + req, err := rpc.NewProxyRequest(connOpts.ClientAddr, conn.LocalAddr(), connOpts, adTag) + if err != nil { + return nil, errors.Annotate(err, "Cannot create new RPC proxy request") + } + + return &MTProtoProxy{ + conn: conn, + logger: conn.Logger().Named("mtproto-proxy"), + req: req, + }, nil +} diff --git a/wrappers/streamcipher.go b/wrappers/streamcipher.go new file mode 100644 index 000000000..b50a351c6 --- /dev/null +++ b/wrappers/streamcipher.go @@ -0,0 +1,65 @@ +package wrappers + +import ( + "crypto/cipher" + "net" + + "github.com/juju/errors" + "go.uber.org/zap" +) + +// StreamCipher is a wrapper which encrypts/decrypts stream with AES-CTR +// (as a part of obfuscated2 protocol). +type StreamCipher struct { + encryptor cipher.Stream + decryptor cipher.Stream + conn StreamReadWriteCloser + logger *zap.SugaredLogger +} + +func (s *StreamCipher) Read(p []byte) (int, error) { + n, err := s.conn.Read(p) + if err != nil { + return 0, errors.Annotate(err, "Cannot read stream ciphered data") + } + s.decryptor.XORKeyStream(p, p[:n]) + + return n, nil +} + +func (s *StreamCipher) Write(p []byte) (int, error) { + encrypted := make([]byte, len(p)) + s.encryptor.XORKeyStream(encrypted, p) + + return s.conn.Write(encrypted) +} + +// Logger returns an instance of the logger for this wrapper. +func (s *StreamCipher) Logger() *zap.SugaredLogger { + return s.logger +} + +// LocalAddr returns local address of the underlying net.Conn. +func (s *StreamCipher) LocalAddr() *net.TCPAddr { + return s.conn.LocalAddr() +} + +// RemoteAddr returns remote address of the underlying net.Conn. +func (s *StreamCipher) RemoteAddr() *net.TCPAddr { + return s.conn.RemoteAddr() +} + +// Close closes underlying net.Conn instance. +func (s *StreamCipher) Close() error { + return s.conn.Close() +} + +// NewStreamCipher creates new stream cipher wrapper. +func NewStreamCipher(conn StreamReadWriteCloser, encryptor, decryptor cipher.Stream) StreamReadWriteCloser { + return &StreamCipher{ + conn: conn, + logger: conn.Logger().Named("stream-cipher"), + encryptor: encryptor, + decryptor: decryptor, + } +} diff --git a/wrappers/streamcipherrwc.go b/wrappers/streamcipherrwc.go deleted file mode 100644 index 1d7d73c96..000000000 --- a/wrappers/streamcipherrwc.go +++ /dev/null @@ -1,54 +0,0 @@ -package wrappers - -import ( - "crypto/cipher" - "io" -) - -// StreamCipherReadWriteCloser is a ReadWriteCloser which ciphers -// incoming and outgoing data with givem cipher.Stream instances. -type StreamCipherReadWriteCloser struct { - encryptor cipher.Stream - decryptor cipher.Stream - conn io.ReadWriteCloser -} - -// Read reads from connection -func (c *StreamCipherReadWriteCloser) Read(p []byte) (n int, err error) { - n, err = c.conn.Read(p) - c.decryptor.XORKeyStream(p, p[:n]) - return -} - -// Write writes into connection. -func (c *StreamCipherReadWriteCloser) Write(p []byte) (int, error) { - // This is to decrease an amount of allocations. Unfortunately, escape - // analysis in (at least Golang 1.10) is absolutely not perfect. For - // example, it understands that we want to have a slice locally, right? - // But since slice is effectively 2 ints + uintptr to [number]byte, the - // most heavyweight part is placed in heap. - buf := getBuffer() - defer putBuffer(buf) - buf.Grow(len(p)) - buf.Write(p) - - encrypted := buf.Bytes() - c.encryptor.XORKeyStream(encrypted, p) - - return c.conn.Write(encrypted) -} - -// Close closes underlying connection. -func (c *StreamCipherReadWriteCloser) Close() error { - return c.conn.Close() -} - -// NewStreamCipherRWC returns wrapper which transparently -// encrypts/decrypts traffic with obfuscated2 protocol. -func NewStreamCipherRWC(conn io.ReadWriteCloser, encryptor, decryptor cipher.Stream) io.ReadWriteCloser { - return &StreamCipherReadWriteCloser{ - conn: conn, - encryptor: encryptor, - decryptor: decryptor, - } -} diff --git a/wrappers/trafficrwc.go b/wrappers/trafficrwc.go deleted file mode 100644 index 485a54c84..000000000 --- a/wrappers/trafficrwc.go +++ /dev/null @@ -1,39 +0,0 @@ -package wrappers - -import "io" - -// TrafficReadWriteCloser counts an amount of ingress/egress traffic by -// calling given callbacks. -type TrafficReadWriteCloser struct { - conn io.ReadWriteCloser - readCallback func(int) - writeCallback func(int) -} - -// Read reads from connection -func (t *TrafficReadWriteCloser) Read(p []byte) (n int, err error) { - n, err = t.conn.Read(p) - t.readCallback(n) - return -} - -// Write writes into connection. -func (t *TrafficReadWriteCloser) Write(p []byte) (n int, err error) { - n, err = t.conn.Write(p) - t.writeCallback(n) - return -} - -// Close closes underlying connection. -func (t *TrafficReadWriteCloser) Close() error { - return t.conn.Close() -} - -// NewTrafficRWC wraps ReadWriteCloser to have read/write callbacks. -func NewTrafficRWC(conn io.ReadWriteCloser, readCallback, writeCallback func(int)) io.ReadWriteCloser { - return &TrafficReadWriteCloser{ - conn: conn, - readCallback: readCallback, - writeCallback: writeCallback, - } -} diff --git a/wrappers/wrap.go b/wrappers/wrap.go new file mode 100644 index 000000000..5e9581b04 --- /dev/null +++ b/wrappers/wrap.go @@ -0,0 +1,111 @@ +package wrappers + +import ( + "io" + "net" + + "go.uber.org/zap" +) + +// Wrap is a base interface for all wrappers in this package. +type Wrap interface { + Logger() *zap.SugaredLogger + LocalAddr() *net.TCPAddr + RemoteAddr() *net.TCPAddr +} + +// Writer is a base interface for writers of this package. +type Writer interface { + io.Writer + Wrap +} + +// Closer is a base interface for wrappers of this package which can +// close connections. +type Closer interface { + io.Closer + Wrap +} + +// WriteCloser is a base interface for wrappers of this package which +// can write to and close connections. +type WriteCloser interface { + io.Closer + Writer +} + +// StreamReader is a base interface for wrappers which can read from the +// stream. +type StreamReader interface { + io.Reader + Wrap +} + +// StreamReadCloser is a base interface for wrappers which can read from +// and close the connections. +type StreamReadCloser interface { + io.Closer + StreamReader +} + +// StreamReadWriter is a base interface for wrappers which can read from +// and write to the connections. +type StreamReadWriter interface { + io.Writer + StreamReader +} + +// StreamWriteCloser is a base interface for wrappers which can write to +// and close the connections. +type StreamWriteCloser interface { + io.WriteCloser + Wrap +} + +// StreamReadWriteCloser is a base interface for stream processors. +type StreamReadWriteCloser interface { + io.Closer + StreamReadWriter +} + +// PacketReader is a base interface for wrappers which reads 'packets'. +// packets are atoms so you either get a packet or you get an error You +// cannot resume reading from packet. +type PacketReader interface { + Read() ([]byte, error) + Wrap +} + +// PacketWriter is a base interface for wrappers which can write packets. +type PacketWriter interface { + io.Writer + Wrap +} + +// PacketReadWriter is a base interface for wrappers which can read from +// and write packets. +type PacketReadWriter interface { + io.Writer + PacketReader +} + +// PacketReadCloser is a base interface for wrappers which can read +// packets and close the connection. +type PacketReadCloser interface { + io.Closer + PacketReader +} + +// PacketWriteCloser is a base interface for wrappers which can write +// packets and close the connection. +type PacketWriteCloser interface { + io.Writer + io.Closer + Wrap +} + +// PacketReadWriteCloser is a base interface for packet processors. +type PacketReadWriteCloser interface { + io.Closer + PacketReadWriter +}