Skip to content

Commit

Permalink
Implement ack low water marks
Browse files Browse the repository at this point in the history
  • Loading branch information
plorenz committed Sep 19, 2024
1 parent 622e75c commit 36f8225
Show file tree
Hide file tree
Showing 13 changed files with 872 additions and 106 deletions.
6 changes: 5 additions & 1 deletion router/forwarder/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,16 @@ func (forwarder *Forwarder) ForwardPayload(srcAddr xgress.Address, payload *xgre
}

func (forwarder *Forwarder) RetransmitPayload(srcAddr xgress.Address, payload *xgress.Payload) error {
return forwarder.forwardPayload(srcAddr, payload, false, time.Second)
return forwarder.forwardPayload(srcAddr, payload, false, 0)
}

func (forwarder *Forwarder) forwardPayload(srcAddr xgress.Address, payload *xgress.Payload, markActive bool, timeout time.Duration) error {
log := pfxlog.ContextLogger(string(srcAddr))

if payload.IsCircuitEndFlagSet() {
pfxlog.Logger().Info("forwarding end-of-circuit")
}

circuitId := payload.GetCircuitId()
if forwardTable, found := forwarder.circuits.getForwardTable(circuitId, markActive); found {
if dstAddr, found := forwardTable.getForwardAddress(srcAddr); found {
Expand Down
1 change: 1 addition & 0 deletions router/handler_link/payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func (self *payloadHandler) HandleReceive(msg *channel.Message, ch channel.Chann
self.forwarder.ReportForwardingFault(payload.CircuitId, "")
}
if payload.IsCircuitEndFlagSet() {
pfxlog.Logger().Info("received end-of-circuit")
self.forwarder.EndCircuit(payload.GetCircuitId())
}
} else {
Expand Down
27 changes: 23 additions & 4 deletions router/xgress/link_receive_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
type LinkReceiveBuffer struct {
tree *btree.Tree
sequence int32
lowWaterMark int32
maxSequence int32
size uint32
lastBufferSizeSent uint32
Expand All @@ -45,15 +46,15 @@ func (buffer *LinkReceiveBuffer) Size() uint32 {
return atomic.LoadUint32(&buffer.size)
}

func (buffer *LinkReceiveBuffer) ReceiveUnordered(payload *Payload, maxSize uint32) bool {
func (buffer *LinkReceiveBuffer) ReceiveUnordered(payload *Payload, maxSize uint32) (bool, int32) {
if payload.GetSequence() <= buffer.sequence {
duplicatePayloadsMeter.Mark(1)
return true
return true, buffer.lowWaterMark
}

if atomic.LoadUint32(&buffer.size) > maxSize && payload.Sequence > buffer.maxSequence {
droppedPayloadsMeter.Mark(1)
return false
return false, 0
}

treeSize := buffer.tree.Size()
Expand All @@ -68,7 +69,25 @@ func (buffer *LinkReceiveBuffer) ReceiveUnordered(payload *Payload, maxSize uint
} else {
duplicatePayloadsMeter.Mark(1)
}
return true

if payload.Sequence <= buffer.lowWaterMark {
return true, buffer.lowWaterMark
}

if payload.Sequence == buffer.lowWaterMark+1 {
buffer.lowWaterMark++
for buffer.canHighWaterMarkBeAdvanced() {
buffer.lowWaterMark++
}
return true, buffer.lowWaterMark
}

return true, buffer.lowWaterMark
}

func (buffer *LinkReceiveBuffer) canHighWaterMarkBeAdvanced() bool {
_, found := buffer.tree.Get(buffer.lowWaterMark + 1)
return found
}

func (buffer *LinkReceiveBuffer) PeekHead() *Payload {
Expand Down
80 changes: 52 additions & 28 deletions router/xgress/link_send_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type LinkSendBuffer struct {
lastRetransmitTime int64
closeWhenEmpty atomic.Bool
inspectRequests chan *sendBufferInspectEvent
minSeq int32
}

type txPayload struct {
Expand Down Expand Up @@ -276,35 +277,9 @@ func (buffer *LinkSendBuffer) close() {

func (buffer *LinkSendBuffer) receiveAcknowledgement(ack *Acknowledgement) {
log := pfxlog.ContextLogger(buffer.x.Label()).WithFields(ack.GetLoggerFields())

for _, sequence := range ack.Sequence {
if txPayload, found := buffer.buffer[sequence]; found {
if txPayload.markAcked() { // if it's been queued for retransmission, remove it from the queue
retransmitter.queue(txPayload)
}

payloadSize := uint32(len(txPayload.payload.Data))
buffer.accumulator += payloadSize
buffer.successfulAcks++
delete(buffer.buffer, sequence)
atomic.AddInt64(&outstandingPayloads, -1)
atomic.AddInt64(&outstandingPayloadBytes, -int64(payloadSize))
buffer.linkSendBufferSize -= payloadSize
log.Debugf("removing payload %v with size %v. payload buffer size: %v",
txPayload.payload.Sequence, len(txPayload.payload.Data), buffer.linkSendBufferSize)

if buffer.successfulAcks >= buffer.x.Options.TxPortalIncreaseThresh {
buffer.successfulAcks = 0
delta := uint32(float64(buffer.accumulator) * buffer.x.Options.TxPortalIncreaseScale)
buffer.windowsSize += delta
if buffer.windowsSize > buffer.x.Options.TxPortalMaxSize {
buffer.windowsSize = buffer.x.Options.TxPortalMaxSize
}
buffer.retxScale -= 0.01
if buffer.retxScale < buffer.x.Options.RetxScale {
buffer.retxScale = buffer.x.Options.RetxScale
}
}
if payload, found := buffer.buffer[sequence]; found {
buffer.markPayloadComplete(payload, sequence, log)
} else { // duplicate ack
duplicateAcksMeter.Mark(1)
buffer.duplicateAcks++
Expand All @@ -315,6 +290,10 @@ func (buffer *LinkSendBuffer) receiveAcknowledgement(ack *Acknowledgement) {
}
}

if ack.LowWaterMark > 0 {
buffer.processLowWatermark(ack, log)
}

buffer.linkRecvBufferSize = ack.RecvBufferSize
if ack.RTT > 0 {
rtt := uint16(info.NowInMilliseconds()) - ack.RTT
Expand All @@ -326,6 +305,51 @@ func (buffer *LinkSendBuffer) receiveAcknowledgement(ack *Acknowledgement) {
}
}

func (buffer *LinkSendBuffer) processLowWatermark(ack *Acknowledgement, log *logrus.Entry) {
//fmt.Printf("checking minseq: %d, lowWater: %d\n", buffer.minSeq, ack.LowWaterMark)
for sequence := buffer.minSeq; sequence <= ack.LowWaterMark; sequence++ {
//fmt.Printf("checking minseq: %d, lowWater: %d, seq: %d\n", buffer.minSeq, ack.LowWaterMark, sequence)
if payload, found := buffer.buffer[sequence]; found {
buffer.markPayloadComplete(payload, sequence, log)
}
}
if buffer.minSeq <= ack.LowWaterMark+1 {
buffer.minSeq = ack.LowWaterMark + 1
}
}

func (buffer *LinkSendBuffer) markPayloadComplete(payload *txPayload, sequence int32, log *logrus.Entry) {
if payload.markAcked() { // if it's been queued for retransmission, remove it from the queue
retransmitter.queue(payload)
}

payloadSize := uint32(len(payload.payload.Data))
buffer.accumulator += payloadSize
delete(buffer.buffer, sequence)
atomic.AddInt64(&outstandingPayloads, -1)
atomic.AddInt64(&outstandingPayloadBytes, -int64(payloadSize))
buffer.linkSendBufferSize -= payloadSize
log.Debugf("removing payload %v with size %v. payload buffer size: %v",
payload.payload.Sequence, len(payload.payload.Data), buffer.linkSendBufferSize)

if buffer.successfulAcks >= buffer.x.Options.TxPortalIncreaseThresh {
buffer.successfulAcks = 0
delta := uint32(float64(buffer.accumulator) * buffer.x.Options.TxPortalIncreaseScale)
buffer.windowsSize += delta
if buffer.windowsSize > buffer.x.Options.TxPortalMaxSize {
buffer.windowsSize = buffer.x.Options.TxPortalMaxSize
}
buffer.retxScale -= 0.01
if buffer.retxScale < buffer.x.Options.RetxScale {
buffer.retxScale = buffer.x.Options.RetxScale
}
}

if sequence == buffer.minSeq {
buffer.minSeq = buffer.minSeq + 1
}
}

func (buffer *LinkSendBuffer) retransmit() {
now := info.NowInMilliseconds()
if len(buffer.buffer) > 0 && (now-buffer.lastRetransmitTime) > 64 {
Expand Down
125 changes: 60 additions & 65 deletions router/xgress/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ const (
HeaderKeyFlags = 2258
HeaderKeyRecvBufferSize = 2259
HeaderKeyRTT = 2260
HeaderPayloadRaw = 2261
HeaderKeyLowWaterMark = 2262

ContentTypePayloadType = 1100
ContentTypeAcknowledgementType = 1101
Expand All @@ -62,80 +64,46 @@ func (o Originator) String() string {
return "Terminator"
}

type PayloadFlag uint32
type Flag uint32

const (
PayloadFlagCircuitEnd PayloadFlag = 1
PayloadFlagOriginator PayloadFlag = 2
PayloadFlagCircuitStart PayloadFlag = 4
PayloadFlagChunk PayloadFlag = 8
PayloadFlagCircuitEnd Flag = 1
PayloadFlagOriginator Flag = 2
PayloadFlagCircuitStart Flag = 4
PayloadFlagChunk Flag = 8
)

type Header struct {
func NewAcknowledgement(circuitId string, originator Originator) *Acknowledgement {
return &Acknowledgement{
CircuitId: circuitId,
Flags: SetOriginatorFlag(0, originator),
}
}

type Acknowledgement struct {
CircuitId string
Flags uint32
RecvBufferSize uint32
RTT uint16
Sequence []int32
LowWaterMark int32
}

func (header *Header) GetCircuitId() string {
return header.CircuitId
func (ack *Acknowledgement) GetCircuitId() string {
return ack.CircuitId
}

func (header *Header) GetFlags() uint32 {
return header.Flags
func (ack *Acknowledgement) GetFlags() uint32 {
return ack.Flags
}

func (header *Header) GetOriginator() Originator {
if isPayloadFlagSet(header.Flags, PayloadFlagOriginator) {
func (ack *Acknowledgement) GetOriginator() Originator {
if isFlagSet(ack.Flags, PayloadFlagOriginator) {
return Terminator
}
return Initiator
}

func (header *Header) unmarshallHeader(msg *channel.Message) error {
circuitId, ok := msg.Headers[HeaderKeyCircuitId]
if !ok {
return fmt.Errorf("no circuitId found in xgress payload message")
}

// If no flags are present, it just means no flags have been set
flags, _ := msg.GetUint32Header(HeaderKeyFlags)

header.CircuitId = string(circuitId)
header.Flags = flags
if header.RecvBufferSize, ok = msg.GetUint32Header(HeaderKeyRecvBufferSize); !ok {
header.RecvBufferSize = math.MaxUint32
}

header.RTT, _ = msg.GetUint16Header(HeaderKeyRTT)

return nil
}

func (header *Header) marshallHeader(msg *channel.Message) {
msg.Headers[HeaderKeyCircuitId] = []byte(header.CircuitId)
if header.Flags != 0 {
msg.PutUint32Header(HeaderKeyFlags, header.Flags)
}

msg.PutUint32Header(HeaderKeyRecvBufferSize, header.RecvBufferSize)
}

func NewAcknowledgement(circuitId string, originator Originator) *Acknowledgement {
return &Acknowledgement{
Header: Header{
CircuitId: circuitId,
Flags: SetOriginatorFlag(0, originator),
},
}
}

type Acknowledgement struct {
Header
Sequence []int32
}

func (ack *Acknowledgement) GetSequence() []int32 {
return ack.Sequence
}
Expand Down Expand Up @@ -174,20 +142,43 @@ func (ack *Acknowledgement) unmarshallSequence(data []byte) error {
func (ack *Acknowledgement) Marshall() *channel.Message {
msg := channel.NewMessage(ContentTypeAcknowledgementType, ack.marshallSequence())
msg.PutUint16Header(HeaderKeyRTT, ack.RTT)
ack.marshallHeader(msg)
msg.Headers[HeaderKeyCircuitId] = []byte(ack.CircuitId)
if ack.Flags != 0 {
msg.PutUint32Header(HeaderKeyFlags, ack.Flags)
}
msg.PutUint32Header(HeaderKeyRecvBufferSize, ack.RecvBufferSize)
if ack.LowWaterMark > 0 {
msg.PutUint32Header(HeaderKeyLowWaterMark, uint32(ack.LowWaterMark))
}
return msg
}

func UnmarshallAcknowledgement(msg *channel.Message) (*Acknowledgement, error) {
ack := &Acknowledgement{}

if err := ack.unmarshallHeader(msg); err != nil {
return nil, err
circuitId, ok := msg.Headers[HeaderKeyCircuitId]
if !ok {
return nil, fmt.Errorf("no circuitId found in xgress payload message")
}

// If no flags are present, it just means no flags have been set
flags, _ := msg.GetUint32Header(HeaderKeyFlags)

ack.CircuitId = string(circuitId)
ack.Flags = flags
if ack.RecvBufferSize, ok = msg.GetUint32Header(HeaderKeyRecvBufferSize); !ok {
ack.RecvBufferSize = math.MaxUint32
}

ack.RTT, _ = msg.GetUint16Header(HeaderKeyRTT)

if err := ack.unmarshallSequence(msg.Body); err != nil {
return nil, err
}

lowWaterMark, _ := msg.GetUint32Header(HeaderKeyLowWaterMark)
ack.LowWaterMark = int32(lowWaterMark)

return ack, nil
}

Expand Down Expand Up @@ -290,15 +281,19 @@ func UnmarshallPayload(msg *channel.Message) (*Payload, error) {
}
payload.Sequence = int32(sequence)

if raw, ok := msg.Headers[HeaderPayloadRaw]; ok {
payload.raw = raw
}

return payload, nil
}

func isPayloadFlagSet(flags uint32, flag PayloadFlag) bool {
return PayloadFlag(flags)&flag == flag
func isFlagSet(flags uint32, flag Flag) bool {
return Flag(flags)&flag == flag
}

func setPayloadFlag(flags uint32, flag PayloadFlag) uint32 {
return uint32(PayloadFlag(flags) | flag)
func setPayloadFlag(flags uint32, flag Flag) uint32 {
return uint32(Flag(flags) | flag)
}

func (payload *Payload) GetCircuitId() string {
Expand All @@ -310,15 +305,15 @@ func (payload *Payload) GetFlags() uint32 {
}

func (payload *Payload) IsCircuitEndFlagSet() bool {
return isPayloadFlagSet(payload.Flags, PayloadFlagCircuitEnd)
return isFlagSet(payload.Flags, PayloadFlagCircuitEnd)
}

func (payload *Payload) IsCircuitStartFlagSet() bool {
return isPayloadFlagSet(payload.Flags, PayloadFlagCircuitStart)
return isFlagSet(payload.Flags, PayloadFlagCircuitStart)
}

func (payload *Payload) GetOriginator() Originator {
if isPayloadFlagSet(payload.Flags, PayloadFlagOriginator) {
if isFlagSet(payload.Flags, PayloadFlagOriginator) {
return Terminator
}
return Initiator
Expand Down
Loading

0 comments on commit 36f8225

Please sign in to comment.