diff --git a/callbacks.go b/callbacks.go index f9ba131b13..0c66bc50f7 100644 --- a/callbacks.go +++ b/callbacks.go @@ -11,10 +11,14 @@ import ( type Callbacks struct { // Called after a peer connection completes the BitTorrent handshake. The Client lock is not // held. - CompletedHandshake func(*PeerConn, InfoHash) - ReadMessage func(*PeerConn, *pp.Message) + CompletedHandshake func(*PeerConn, InfoHash) + ReadMessage func(*PeerConn, *pp.Message) + // This can be folded into the general case below. ReadExtendedHandshake func(*PeerConn, *pp.ExtendedHandshakeMessage) PeerConnClosed func(*PeerConn) + // BEP 10 message. Not sure if I should call this Ltep universally. Each handler here is called + // in order. + PeerConnReadExtensionMessage []func(PeerConnReadExtensionMessageEvent) // Provides secret keys to be tried against incoming encrypted connections. ReceiveEncryptedHandshakeSkeys mse.SecretKeyIter @@ -25,6 +29,11 @@ type Callbacks struct { SentRequest []func(PeerRequestEvent) PeerClosed []func(*Peer) NewPeer []func(*Peer) + // Called when a PeerConn has been added to a Torrent. It's finished all BitTorrent protocol + // handshakes, and is about to start sending and receiving BitTorrent messages. The extended + // handshake has not yet occurred. This is a good time to alter the supported extension + // protocols. + PeerConnAdded []func(*PeerConn) } type ReceivedUsefulDataEvent = PeerMessageEvent @@ -38,3 +47,10 @@ type PeerRequestEvent struct { Peer *Peer Request } + +type PeerConnReadExtensionMessageEvent struct { + PeerConn *PeerConn + // You can look up what protocol this corresponds to using the PeerConn.LocalLtepProtocolMap. + ExtensionNumber pp.ExtensionNumber + Payload []byte +} diff --git a/client.go b/client.go index 62c0d2b4c5..f2f7f77231 100644 --- a/client.go +++ b/client.go @@ -43,7 +43,6 @@ import ( "github.com/anacrolix/torrent/metainfo" "github.com/anacrolix/torrent/mse" pp "github.com/anacrolix/torrent/peer_protocol" - utHolepunch "github.com/anacrolix/torrent/peer_protocol/ut-holepunch" request_strategy "github.com/anacrolix/torrent/request-strategy" "github.com/anacrolix/torrent/storage" "github.com/anacrolix/torrent/tracker" @@ -1078,6 +1077,10 @@ func (t *Torrent) runHandshookConn(pc *PeerConn) error { return fmt.Errorf("adding connection: %w", err) } defer t.dropConnection(pc) + pc.addBuiltinLtepProtocols(!cl.config.DisablePEX) + for _, cb := range pc.callbacks.PeerConnAdded { + cb(pc) + } pc.startMessageWriter() pc.sendInitialMessages() pc.initUpdateRequestsTimer() @@ -1135,10 +1138,6 @@ func (pc *PeerConn) sendInitialMessages() { ExtendedID: pp.HandshakeExtendedID, ExtendedPayload: func() []byte { msg := pp.ExtendedHandshakeMessage{ - M: map[pp.ExtensionName]pp.ExtensionNumber{ - pp.ExtensionNameMetadata: metadataExtendedId, - utHolepunch.ExtensionName: utHolepunchExtendedId, - }, V: cl.config.ExtendedHandshakeClientVersion, Reqq: localClientReqq, YourIp: pp.CompactIp(pc.remoteIp()), @@ -1149,8 +1148,12 @@ func (pc *PeerConn) sendInitialMessages() { Ipv4: pp.CompactIp(cl.config.PublicIp4.To4()), Ipv6: cl.config.PublicIp6.To16(), } - if !cl.config.DisablePEX { - msg.M[pp.ExtensionNamePex] = pexExtendedId + g.MakeMapWithCap(&msg.M, len(pc.LocalLtepProtocolMap)) + for i, name := range pc.LocalLtepProtocolMap { + old := g.MapInsert(msg.M, name, pp.ExtensionNumber(i+1)) + if old.Ok { + panic(fmt.Sprintf("extension %q already defined with id %v", name, old.Value)) + } } return bencode.MustMarshal(msg) }(), diff --git a/global.go b/global.go index 5a5bddba0a..585bbeafaa 100644 --- a/global.go +++ b/global.go @@ -19,14 +19,6 @@ const ( maxMetadataSize uint32 = 16 * 1024 * 1024 ) -// These are our extended message IDs. Peers will use these values to -// select which extension a message is intended for. -const ( - metadataExtendedId = iota + 1 // 0 is reserved for deleting keys - pexExtendedId - utHolepunchExtendedId -) - func defaultPeerExtensionBytes() PeerExtensionBits { return pp.NewPeerExtensionBytes(pp.ExtensionBitDht, pp.ExtensionBitLtep, pp.ExtensionBitFast) } diff --git a/peer_protocol/extended.go b/peer_protocol/extended.go index 8bc5181633..019590e40a 100644 --- a/peer_protocol/extended.go +++ b/peer_protocol/extended.go @@ -24,7 +24,7 @@ type ( } ExtensionName string - ExtensionNumber int + ExtensionNumber uint8 ) const ( diff --git a/peerconn.go b/peerconn.go index 0dd99cef1f..d4fef6ef7f 100644 --- a/peerconn.go +++ b/peerconn.go @@ -44,6 +44,14 @@ type PeerConn struct { PeerID PeerID PeerExtensionBytes pp.PeerExtensionBits PeerListenPort int + // 1-based mapping from extension number to extension name (subtract one from the extension ID + // to find the corresponding protocol name). The first LocalLtepProtocolBuiltinCount of these + // are use builtin handlers. If you want to handle builtin protocols yourself, you would move + // them above the threshold. You can disable them by removing them entirely, and add your own. + // These changes should be done in the PeerConnAdded callback. + LocalLtepProtocolMap []pp.ExtensionName + // How many of the protocols are using the builtin handlers. + LocalLtepProtocolBuiltinCount int // The actual Conn, used for closing, and setting socket options. Do not use methods on this // while holding any mutexes. @@ -55,6 +63,7 @@ type PeerConn struct { messageWriter peerConnMsgWriter + // The peer's extension map, as sent in their extended handshake. PeerExtensionIDs map[pp.ExtensionName]pp.ExtensionNumber PeerClientName atomic.Value uploadTimer *time.Timer @@ -869,8 +878,17 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err }() t := c.t cl := t.cl - switch id { - case pp.HandshakeExtendedID: + { + event := PeerConnReadExtensionMessageEvent{ + PeerConn: c, + ExtensionNumber: id, + Payload: payload, + } + for _, cb := range c.callbacks.PeerConnReadExtensionMessage { + cb(event) + } + } + if id == pp.HandshakeExtendedID { var d pp.ExtendedHandshakeMessage if err := bencode.Unmarshal(payload, &d); err != nil { c.logger.Printf("error parsing extended handshake message %q: %s", payload, err) @@ -879,7 +897,6 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err if cb := c.callbacks.ReadExtendedHandshake; cb != nil { cb(c, &d) } - // c.logger.WithDefaultLevel(log.Debug).Printf("received extended handshake message:\n%s", spew.Sdump(d)) if d.Reqq != 0 { c.PeerMaxRequests = d.Reqq } @@ -911,13 +928,25 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err c.pex.Init(c) } return nil - case metadataExtendedId: + } + // Zero was taken care of above. + protocolIndex := int(id - 1) + if protocolIndex >= len(c.LocalLtepProtocolMap) { + return fmt.Errorf("unexpected extended message ID: %v", id) + } + if protocolIndex >= c.LocalLtepProtocolBuiltinCount { + // The message should have been handled by the PeerConnReadExtensionMessage callback. + return nil + } + extensionName := c.LocalLtepProtocolMap[protocolIndex] + switch extensionName { + case pp.ExtensionNameMetadata: err := cl.gotMetadataExtensionMsg(payload, t, c) if err != nil { return fmt.Errorf("handling metadata extension message: %w", err) } return nil - case pexExtendedId: + case pp.ExtensionNamePex: if !c.pex.IsEnabled() { return nil // or hang-up maybe? } @@ -926,7 +955,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err err = fmt.Errorf("receiving pex message: %w", err) } return - case utHolepunchExtendedId: + case utHolepunch.ExtensionName: var msg utHolepunch.Msg err = msg.UnmarshalBinary(payload) if err != nil { @@ -936,7 +965,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err err = c.t.handleReceivedUtHolepunchMsg(msg, c) return default: - return fmt.Errorf("unexpected extended message ID: %v", id) + panic(fmt.Sprintf("unhandled builtin extension protocol %q", extensionName)) } } @@ -1144,3 +1173,15 @@ func (c *PeerConn) useful() bool { } return false } + +func (c *PeerConn) addBuiltinLtepProtocols(pex bool) { + ps := []pp.ExtensionName{pp.ExtensionNameMetadata, utHolepunch.ExtensionName} + if pex { + ps = append(ps, pp.ExtensionNamePex) + } + if c.LocalLtepProtocolMap != nil { + panic("already set") + } + c.LocalLtepProtocolMap = ps + c.LocalLtepProtocolBuiltinCount = len(ps) +} diff --git a/pexconn_test.go b/pexconn_test.go index f8b9c9e07a..b8be73e887 100644 --- a/pexconn_test.go +++ b/pexconn_test.go @@ -22,7 +22,7 @@ func TestPexConnState(t *testing.T) { network: addr.Network(), }) c.PeerExtensionIDs = make(map[pp.ExtensionName]pp.ExtensionNumber) - c.PeerExtensionIDs[pp.ExtensionNamePex] = pexExtendedId + c.PeerExtensionIDs[pp.ExtensionNamePex] = 1 c.messageWriter.mu.Lock() c.setTorrent(torrent) if err := torrent.addPeerConn(c); err != nil { @@ -45,7 +45,8 @@ func TestPexConnState(t *testing.T) { c.pex.Share(testWriter) require.True(t, writerCalled) require.EqualValues(t, pp.Extended, out.Type) - require.EqualValues(t, pexExtendedId, out.ExtendedID) + require.NotEqualValues(t, out.ExtendedID, 0) + require.EqualValues(t, c.PeerExtensionIDs[pp.ExtensionNamePex], out.ExtendedID) x, err := pp.LoadPexMsg(out.ExtendedPayload) require.NoError(t, err)