Skip to content

Commit

Permalink
Start implementing user BEP 10 extension support
Browse files Browse the repository at this point in the history
  • Loading branch information
anacrolix committed Jan 15, 2024
1 parent 867996b commit 7f85744
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 27 deletions.
20 changes: 18 additions & 2 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ import (
type Callbacks struct {
// Called after a peer connection completes the BitTorrent handshake. The Client lock is not
// held.
CompletedHandshake func(*PeerConn, InfoHash)
ReadMessage func(*PeerConn, *pp.Message)
CompletedHandshake func(*PeerConn, InfoHash)
ReadMessage func(*PeerConn, *pp.Message)
// This can be folded into the general case below.
ReadExtendedHandshake func(*PeerConn, *pp.ExtendedHandshakeMessage)
PeerConnClosed func(*PeerConn)
// BEP 10 message. Not sure if I should call this Ltep universally. Each handler here is called
// in order.
PeerConnReadExtensionMessage []func(PeerConnReadExtensionMessageEvent)

// Provides secret keys to be tried against incoming encrypted connections.
ReceiveEncryptedHandshakeSkeys mse.SecretKeyIter
Expand All @@ -25,6 +29,11 @@ type Callbacks struct {
SentRequest []func(PeerRequestEvent)
PeerClosed []func(*Peer)
NewPeer []func(*Peer)
// Called when a PeerConn has been added to a Torrent. It's finished all BitTorrent protocol
// handshakes, and is about to start sending and receiving BitTorrent messages. The extended
// handshake has not yet occurred. This is a good time to alter the supported extension
// protocols.
PeerConnAdded []func(*PeerConn)
}

type ReceivedUsefulDataEvent = PeerMessageEvent
Expand All @@ -38,3 +47,10 @@ type PeerRequestEvent struct {
Peer *Peer
Request
}

type PeerConnReadExtensionMessageEvent struct {
PeerConn *PeerConn
// You can look up what protocol this corresponds to using the PeerConn.LocalLtepProtocolMap.
ExtensionNumber pp.ExtensionNumber
Payload []byte
}
17 changes: 10 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ import (
"github.com/anacrolix/torrent/metainfo"
"github.com/anacrolix/torrent/mse"
pp "github.com/anacrolix/torrent/peer_protocol"
utHolepunch "github.com/anacrolix/torrent/peer_protocol/ut-holepunch"
request_strategy "github.com/anacrolix/torrent/request-strategy"
"github.com/anacrolix/torrent/storage"
"github.com/anacrolix/torrent/tracker"
Expand Down Expand Up @@ -1078,6 +1077,10 @@ func (t *Torrent) runHandshookConn(pc *PeerConn) error {
return fmt.Errorf("adding connection: %w", err)
}
defer t.dropConnection(pc)
pc.addBuiltinLtepProtocols(!cl.config.DisablePEX)
for _, cb := range pc.callbacks.PeerConnAdded {
cb(pc)
}
pc.startMessageWriter()
pc.sendInitialMessages()
pc.initUpdateRequestsTimer()
Expand Down Expand Up @@ -1135,10 +1138,6 @@ func (pc *PeerConn) sendInitialMessages() {
ExtendedID: pp.HandshakeExtendedID,
ExtendedPayload: func() []byte {
msg := pp.ExtendedHandshakeMessage{
M: map[pp.ExtensionName]pp.ExtensionNumber{
pp.ExtensionNameMetadata: metadataExtendedId,
utHolepunch.ExtensionName: utHolepunchExtendedId,
},
V: cl.config.ExtendedHandshakeClientVersion,
Reqq: localClientReqq,
YourIp: pp.CompactIp(pc.remoteIp()),
Expand All @@ -1149,8 +1148,12 @@ func (pc *PeerConn) sendInitialMessages() {
Ipv4: pp.CompactIp(cl.config.PublicIp4.To4()),
Ipv6: cl.config.PublicIp6.To16(),
}
if !cl.config.DisablePEX {
msg.M[pp.ExtensionNamePex] = pexExtendedId
g.MakeMapWithCap(&msg.M, len(pc.LocalLtepProtocolMap))
for i, name := range pc.LocalLtepProtocolMap {
old := g.MapInsert(msg.M, name, pp.ExtensionNumber(i+1))
if old.Ok {
panic(fmt.Sprintf("extension %q already defined with id %v", name, old.Value))
}
}
return bencode.MustMarshal(msg)
}(),
Expand Down
8 changes: 0 additions & 8 deletions global.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,6 @@ const (
maxMetadataSize uint32 = 16 * 1024 * 1024
)

// These are our extended message IDs. Peers will use these values to
// select which extension a message is intended for.
const (
metadataExtendedId = iota + 1 // 0 is reserved for deleting keys
pexExtendedId
utHolepunchExtendedId
)

func defaultPeerExtensionBytes() PeerExtensionBits {
return pp.NewPeerExtensionBytes(pp.ExtensionBitDht, pp.ExtensionBitLtep, pp.ExtensionBitFast)
}
Expand Down
2 changes: 1 addition & 1 deletion peer_protocol/extended.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type (
}

ExtensionName string
ExtensionNumber int
ExtensionNumber uint8
)

const (
Expand Down
55 changes: 48 additions & 7 deletions peerconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ type PeerConn struct {
PeerID PeerID
PeerExtensionBytes pp.PeerExtensionBits
PeerListenPort int
// 1-based mapping from extension number to extension name (subtract one from the extension ID
// to find the corresponding protocol name). The first LocalLtepProtocolBuiltinCount of these
// are use builtin handlers. If you want to handle builtin protocols yourself, you would move
// them above the threshold. You can disable them by removing them entirely, and add your own.
// These changes should be done in the PeerConnAdded callback.
LocalLtepProtocolMap []pp.ExtensionName
// How many of the protocols are using the builtin handlers.
LocalLtepProtocolBuiltinCount int

// The actual Conn, used for closing, and setting socket options. Do not use methods on this
// while holding any mutexes.
Expand All @@ -55,6 +63,7 @@ type PeerConn struct {

messageWriter peerConnMsgWriter

// The peer's extension map, as sent in their extended handshake.
PeerExtensionIDs map[pp.ExtensionName]pp.ExtensionNumber
PeerClientName atomic.Value
uploadTimer *time.Timer
Expand Down Expand Up @@ -869,8 +878,17 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
}()
t := c.t
cl := t.cl
switch id {
case pp.HandshakeExtendedID:
{
event := PeerConnReadExtensionMessageEvent{
PeerConn: c,
ExtensionNumber: id,
Payload: payload,
}
for _, cb := range c.callbacks.PeerConnReadExtensionMessage {
cb(event)
}
}
if id == pp.HandshakeExtendedID {
var d pp.ExtendedHandshakeMessage
if err := bencode.Unmarshal(payload, &d); err != nil {
c.logger.Printf("error parsing extended handshake message %q: %s", payload, err)
Expand All @@ -879,7 +897,6 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
if cb := c.callbacks.ReadExtendedHandshake; cb != nil {
cb(c, &d)
}
// c.logger.WithDefaultLevel(log.Debug).Printf("received extended handshake message:\n%s", spew.Sdump(d))
if d.Reqq != 0 {
c.PeerMaxRequests = d.Reqq
}
Expand Down Expand Up @@ -911,13 +928,25 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
c.pex.Init(c)
}
return nil
case metadataExtendedId:
}
// Zero was taken care of above.
protocolIndex := int(id - 1)
if protocolIndex >= len(c.LocalLtepProtocolMap) {
return fmt.Errorf("unexpected extended message ID: %v", id)
}
if protocolIndex >= c.LocalLtepProtocolBuiltinCount {
// The message should have been handled by the PeerConnReadExtensionMessage callback.
return nil
}
extensionName := c.LocalLtepProtocolMap[protocolIndex]
switch extensionName {
case pp.ExtensionNameMetadata:
err := cl.gotMetadataExtensionMsg(payload, t, c)
if err != nil {
return fmt.Errorf("handling metadata extension message: %w", err)
}
return nil
case pexExtendedId:
case pp.ExtensionNamePex:
if !c.pex.IsEnabled() {
return nil // or hang-up maybe?
}
Expand All @@ -926,7 +955,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
err = fmt.Errorf("receiving pex message: %w", err)
}
return
case utHolepunchExtendedId:
case utHolepunch.ExtensionName:
var msg utHolepunch.Msg
err = msg.UnmarshalBinary(payload)
if err != nil {
Expand All @@ -936,7 +965,7 @@ func (c *PeerConn) onReadExtendedMsg(id pp.ExtensionNumber, payload []byte) (err
err = c.t.handleReceivedUtHolepunchMsg(msg, c)
return
default:
return fmt.Errorf("unexpected extended message ID: %v", id)
panic(fmt.Sprintf("unhandled builtin extension protocol %q", extensionName))
}
}

Expand Down Expand Up @@ -1144,3 +1173,15 @@ func (c *PeerConn) useful() bool {
}
return false
}

func (c *PeerConn) addBuiltinLtepProtocols(pex bool) {
ps := []pp.ExtensionName{pp.ExtensionNameMetadata, utHolepunch.ExtensionName}
if pex {
ps = append(ps, pp.ExtensionNamePex)
}
if c.LocalLtepProtocolMap != nil {
panic("already set")
}
c.LocalLtepProtocolMap = ps
c.LocalLtepProtocolBuiltinCount = len(ps)
}
5 changes: 3 additions & 2 deletions pexconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestPexConnState(t *testing.T) {
network: addr.Network(),
})
c.PeerExtensionIDs = make(map[pp.ExtensionName]pp.ExtensionNumber)
c.PeerExtensionIDs[pp.ExtensionNamePex] = pexExtendedId
c.PeerExtensionIDs[pp.ExtensionNamePex] = 1
c.messageWriter.mu.Lock()
c.setTorrent(torrent)
if err := torrent.addPeerConn(c); err != nil {
Expand All @@ -45,7 +45,8 @@ func TestPexConnState(t *testing.T) {
c.pex.Share(testWriter)
require.True(t, writerCalled)
require.EqualValues(t, pp.Extended, out.Type)
require.EqualValues(t, pexExtendedId, out.ExtendedID)
require.NotEqualValues(t, out.ExtendedID, 0)
require.EqualValues(t, c.PeerExtensionIDs[pp.ExtensionNamePex], out.ExtendedID)

x, err := pp.LoadPexMsg(out.ExtendedPayload)
require.NoError(t, err)
Expand Down

0 comments on commit 7f85744

Please sign in to comment.