-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathreceiver.go
342 lines (301 loc) · 10.6 KB
/
receiver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
/*
Copyright 2013-Present Couchbase, Inc.
Use of this software is governed by the Business Source License included in
the file licenses/BSL-Couchbase.txt. As of the Change Date specified in that
file, in accordance with the Business Source License, use of this software will
be governed by the Apache License, Version 2.0, included in the file
licenses/APL2.txt.
*/
package blip
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"log"
"runtime/debug"
"sync"
"sync/atomic"
"github.com/coder/websocket"
)
// checksumLength is the number of bytes taken by the big-endian uint32 checksum
// that trails every non-ACK frame.
const (
	checksumLength = 4
)
type (
	// msgStreamer pairs a message being received with the writer that its decoded
	// frame bodies are streamed into, plus a running byte count used for ACKs.
	msgStreamer struct {
		// The message being received.
		message *Message
		// Destination for the message's decoded frame bodies.
		writer io.WriteCloser
		// Frame bytes received so far; processFrame uses it to decide when to ACK.
		bytesWritten uint64
	}

	// msgStreamerMap indexes in-flight msgStreamers by message number.
	msgStreamerMap map[MessageNumber]*msgStreamer
)
// receiver is the receiving side of a BLIP connection.
// It reads WebSocket messages as frames and assembles them into BLIP messages.
type receiver struct {
	// The BLIP Context that owns this receiver.
	context *Context
	// The underlying WebSocket connection.
	conn *websocket.Conn
	// Raw WebSocket messages handed from receiveLoop to parseLoop.
	channel chan []byte
	// Number of REQ messages received so far; request numbers must arrive in sequence.
	numRequestsReceived MessageNumber
	// The owning Context's Sender (used for ACKs and dispatching requests).
	sender *Sender
	// Scratch buffer holding the frame currently being parsed.
	frameBuffer bytes.Buffer
	// Decompresses (or checksums) frame bodies.
	frameDecoder *decompressor
	// Fatal error reported by the frame parser to receiveLoop.
	parseError chan error
	// Count of goroutines that must finish before teardown completes; updated atomically.
	activeGoroutines int32

	// pendingMutex guards every field below it.
	pendingMutex sync.Mutex
	// REQ messages whose frames are still being assembled.
	pendingRequests msgStreamerMap
	// RES/ERR messages that are awaited or still being assembled.
	pendingResponses msgStreamerMap
	// Largest message number registered via awaitResponse.
	maxPendingResponseNumber MessageNumber
}
// newReceiver builds the receiving half of a BLIP connection over conn, owned by context.
// The sender field is left nil here; presumably the owning Context wires it up afterwards.
func newReceiver(context *Context, conn *websocket.Conn) *receiver {
	r := &receiver{
		context:          context,
		conn:             conn,
		channel:          make(chan []byte, 10),
		parseError:       make(chan error, 1),
		pendingRequests:  make(msgStreamerMap),
		pendingResponses: make(msgStreamerMap),
	}
	r.frameDecoder = getDecompressor(context)
	return r
}
// receiveLoop reads WebSocket messages from r.conn and forwards each one to parseLoop,
// which it starts on its own goroutine. It stays counted in activeGoroutines while it
// runs. On return it closes r.channel, which ends parseLoop's range loop.
//
// It returns when a read fails. A close error is returned as-is. Otherwise the
// parser's fatal error is preferred when one is queued on r.parseError (fatalError
// queues it before closing the socket). Failing that, the WebSocket error is returned.
func (r *receiver) receiveLoop() error {
	atomic.AddInt32(&r.activeGoroutines, 1)
	defer atomic.AddInt32(&r.activeGoroutines, -1)

	go r.parseLoop()
	defer close(r.channel)

	for {
		// Receive the next raw WebSocket frame:
		_, frame, err := r.conn.Read(r.context.GetCancelCtx())
		if err == nil {
			r.channel <- frame
			continue
		}

		if isCloseError(err) {
			// An orderly close is expected, so only log it at frame level.
			r.context.logFrame("receiveLoop stopped: %v", err)
			return err
		}
		if parseErr := errorFromChannel(r.parseError); parseErr != nil {
			return parseErr
		}
		r.context.log("Error: receiveLoop exiting with WebSocket error: %v", err)
		return err
	}
}
// parseLoop consumes raw frames from r.channel (fed by receiveLoop) and passes each one
// to handleIncomingFrame. It runs until r.channel is closed, a frame fails to parse, or
// parsing panics. In the last two cases the receiver is torn down via fatalError.
func (r *receiver) parseLoop() {
	var frameErr error

	defer func() {
		// Drop this goroutine from activeGoroutines BEFORE any teardown. fatalError
		// calls stop(), which waits for activeGoroutines to reach zero; calling it while
		// still counted would make parseLoop wait on itself.
		atomic.AddInt32(&r.activeGoroutines, -1)

		fatal := frameErr
		if p := recover(); p != nil { // Panic handler:
			log.Printf("PANIC in BLIP parseLoop: %v\n%s", p, debug.Stack())
			err, _ := p.(error)
			if err == nil {
				err = fmt.Errorf("Panic: %v", p)
			}
			fatal = err
		}
		if fatal == nil {
			return // r.channel was closed normally
		}

		// Nothing reads r.channel any more. Without a reader, receiveLoop could block
		// forever on its next send and never see the socket being closed. Discard
		// frames until receiveLoop exits and closes the channel.
		go func() {
			for range r.channel {
			}
		}()
		r.fatalError(fatal)
	}()

	// Update Expvar stats for number of outstanding goroutines
	incrParseLoopGoroutines()
	defer decrParseLoopGoroutines()

	atomic.AddInt32(&r.activeGoroutines, 1)

	for frame := range r.channel {
		r.context.bytesReceived.Add(uint64(len(frame)))
		if frameErr = r.handleIncomingFrame(frame); frameErr != nil {
			break
		}
	}
	r.context.logFrame("parseLoop stopped")
	returnDecompressor(r.frameDecoder)
	r.frameDecoder = nil
}
// fatalError logs err and queues it on r.parseError so receiveLoop can report it.
// It then shuts the receiver down via stop.
func (r *receiver) fatalError(err error) {
	c := r.context
	c.log("Error: parseLoop closing socket due to error: %v", err)

	// Queue the error before closing the socket, so receiveLoop sees it when its read fails.
	r.parseError <- err
	r.stop()
}
// stop tears the receiver down. It closes pending response writers (unblocking their
// readers) and closes the WebSocket with a normal-closure status. It then waits, via
// waitForZeroActiveGoroutines, for the goroutines tracked in activeGoroutines.
func (r *receiver) stop() {
	r.closePendingResponses()

	// Best effort: an error closing an already-failing socket is not actionable here.
	_ = r.conn.Close(websocket.StatusNormalClosure, "")

	waitForZeroActiveGoroutines(r.context, &r.activeGoroutines)
}
// closePendingResponses closes the writer of every response still in pendingResponses.
// Close errors are logged, not returned.
func (r *receiver) closePendingResponses() {
	r.pendingMutex.Lock()
	defer r.pendingMutex.Unlock()

	// Goroutines spawned by message.asyncRead() may be blocked reading their end of an
	// io.Pipe. If the peer closes the connection abruptly and the sender stops, nothing
	// else closes the write side of those pipes. Closing each pending writer here
	// unblocks the readers and lets the asyncRead() goroutines finish.
	for number := range r.pendingResponses {
		if err := r.pendingResponses[number].writer.Close(); err != nil {
			r.context.logMessage("Warning: error closing msgStreamer writer in pending responses while stopping receiver: %v", err)
		}
	}
}
// handleIncomingFrame parses one raw WebSocket message as a BLIP frame.
//
// A frame starts with two uvarints: the message number and the frame flags.
//   - ACK frames carry a uvarint byte count, which is passed to r.sender.receivedAck.
//   - All other frames end in a 4-byte big-endian checksum. Their body is decompressed
//     (or passed through) by r.frameDecoder and handed to processFrame.
//
// A returned error makes parseLoop tear down the connection.
func (r *receiver) handleIncomingFrame(frame []byte) error {
	// Parse BLIP header:
	if len(frame) < 2 {
		return fmt.Errorf("Illegally short frame")
	}
	r.frameBuffer.Reset()
	r.frameBuffer.Write(frame)
	n, err := binary.ReadUvarint(&r.frameBuffer)
	if err != nil {
		return err
	}
	requestNumber := MessageNumber(n)
	n, err = binary.ReadUvarint(&r.frameBuffer)
	if err != nil {
		return err
	}
	flags := frameFlags(n)
	msgType := flags.messageType()

	if msgType.isAck() {
		// ACKs are parsed specially. They don't go through the codec nor contain a checksum:
		body := r.frameBuffer.Bytes()
		if bytesReceived, n := binary.Uvarint(body); n > 0 {
			r.sender.receivedAck(requestNumber, msgType.ackSourceType(), bytesReceived)
		} else {
			r.context.log("Error reading ACK frame: %x", body)
		}
		return nil
	}

	// Regular frames have a checksum. The header varints have already been consumed
	// from frameBuffer, so the length check must be made against what remains.
	// Checking len(frame) instead let a frame with too few post-header bytes slice
	// at a negative index and panic.
	bufferedFrame := r.frameBuffer.Bytes()
	frameSize := len(bufferedFrame)
	if frameSize < checksumLength {
		return fmt.Errorf("Illegally short frame")
	}
	checksum := binary.BigEndian.Uint32(bufferedFrame[frameSize-checksumLength:])
	r.frameBuffer.Truncate(frameSize - checksumLength)

	if r.context.LogFrames {
		r.context.logFrame("Received frame: %s (flags=%8b, length=%d)",
			frameString(requestNumber, flags), flags, r.frameBuffer.Len())
	}

	// Read/decompress the body of the frame:
	var body []byte
	if flags&kCompressed != 0 {
		body, err = r.frameDecoder.decompress(r.frameBuffer.Bytes(), checksum)
	} else {
		body, err = r.frameDecoder.passthrough(r.frameBuffer.Bytes(), &checksum)
	}
	if err != nil {
		r.context.log("Error receiving frame %s: %v. Raw frame = <%x>",
			frameString(requestNumber, flags), err, frame)
		return err
	}
	// frameSize includes the checksum; processFrame counts it toward ACK progress.
	return r.processFrame(requestNumber, flags, body, frameSize)
}
// processFrame routes a decoded frame body to the stream of the message it belongs to.
//
// Request frames go to getPendingRequest and response/error frames to
// getPendingResponse. ACK types are ignored, and unknown types are logged and ignored.
//   - If no stream is found, the lookup error, if any, is returned.
//   - If a stream is found, the body is written to it.
//   - On the final frame (no kMoreComing flag), the stream's writer is closed.
//   - Otherwise bytesWritten grows by frameSize, and an ACK is sent whenever that total
//     crosses a kAckInterval boundary (never on the first frame).
func (r *receiver) processFrame(requestNumber MessageNumber, flags frameFlags, frame []byte, frameSize int) error {
	complete := flags&kMoreComing == 0
	msgType := flags.messageType()

	// Look up or create the writer stream for this message:
	var (
		msgStream *msgStreamer
		lookupErr error
	)
	switch msgType {
	case RequestType:
		msgStream, lookupErr = r.getPendingRequest(requestNumber, flags, complete)
	case ResponseType, ErrorType:
		msgStream, lookupErr = r.getPendingResponse(requestNumber, flags, complete)
	case AckRequestType, AckResponseType:
		// Nothing to route.
	default:
		r.context.log("Warning: Ignoring incoming message type, with flags 0x%x", flags)
	}
	if msgStream == nil {
		return lookupErr
	}

	// Write the decoded frame body to the stream:
	if _, writeErr := writeFull(frame, msgStream.writer); writeErr != nil {
		return writeErr
	}

	if complete {
		if closeErr := msgStream.writer.Close(); closeErr != nil {
			r.context.log("Warning: message writer closed with error %v", closeErr)
		}
		return lookupErr
	}

	//FIX: This isn't the right place to do this, because this goroutine doesn't block even
	// if the client can't read the message fast enough. The right place to send the ACK is
	// in the goroutine that's running msgStream.writer. (Somehow...)
	before := msgStream.bytesWritten
	msgStream.bytesWritten += uint64(frameSize)
	if before > 0 && before/kAckInterval < msgStream.bytesWritten/kAckInterval {
		r.sender.sendAck(requestNumber, msgType, msgStream.bytesWritten)
	}
	return lookupErr
}
// getPendingRequest returns the stream that should receive a request frame numbered
// requestNumber. It holds pendingMutex throughout.
//
// Three cases:
//   - An in-progress request is found: its stream is returned, and it is forgotten
//     once the frame is complete.
//   - requestNumber is exactly one past the last request seen: a new incoming message
//     is created. Its body is read via asyncRead, and the message is dispatched through
//     r.context.dispatchRequest from asyncRead's callback. activeGoroutines counts that
//     async reader until dispatch returns. The stream is remembered only if more frames
//     are coming.
//   - Any other number: an error is returned.
func (r *receiver) getPendingRequest(requestNumber MessageNumber, flags frameFlags, complete bool) (msgStream *msgStreamer, err error) {
	r.pendingMutex.Lock()
	defer r.pendingMutex.Unlock()

	if existing := r.pendingRequests[requestNumber]; existing != nil {
		if complete {
			delete(r.pendingRequests, requestNumber)
		}
		return existing, nil
	}

	if requestNumber != r.numRequestsReceived+1 {
		return nil, fmt.Errorf("Bad incoming request number %d", requestNumber)
	}
	r.numRequestsReceived++

	request := newIncomingMessage(r.sender, requestNumber, flags, nil)
	atomic.AddInt32(&r.activeGoroutines, 1)
	onRead := func(_ error) {
		r.context.dispatchRequest(request, r.sender)
		atomic.AddInt32(&r.activeGoroutines, -1)
	}
	stream := &msgStreamer{
		message: request,
		writer:  request.asyncRead(onRead),
	}
	if !complete {
		r.pendingRequests[requestNumber] = stream
	}
	return stream, nil
}
// getPendingResponse returns the stream registered by awaitResponse for a response
// frame numbered requestNumber. It holds pendingMutex throughout.
//
// Three cases:
//   - A stream is registered: if no bytes have been recorded for it yet, this frame's
//     flags are stored on the message. The stream is forgotten once the frame is complete.
//   - No stream, and the number is at or below maxPendingResponseNumber: a warning is
//     logged and (nil, nil) is returned.
//   - No stream, and the number is above maxPendingResponseNumber: an error is returned.
func (r *receiver) getPendingResponse(requestNumber MessageNumber, flags frameFlags, complete bool) (msgStream *msgStreamer, err error) {
	r.pendingMutex.Lock()
	defer r.pendingMutex.Unlock()

	msgStream = r.pendingResponses[requestNumber]
	switch {
	case msgStream != nil:
		if msgStream.bytesWritten == 0 {
			// Set flags based on the 1st frame of the response.
			msgStream.message.flags.Store(&flags)
		}
		if complete {
			delete(r.pendingResponses, requestNumber)
		}
	case requestNumber <= r.maxPendingResponseNumber:
		// A response to a request that wasn't expecting one; harmless.
		r.context.log("Warning: Unexpected response frame to my msg #%d", requestNumber) // benign
	default:
		// Higher than any request number registered so far.
		err = fmt.Errorf("Bogus message number %d in response. Expected to be less than max pending response number (%d)", requestNumber, r.maxPendingResponseNumber)
	}
	return msgStream, err
}
// awaitResponse registers writer as the destination for frames of the response to
// request, keyed by request.number. It also raises maxPendingResponseNumber if needed;
// getPendingResponse uses that value to tell unexpected responses from bogus ones.
//
// pendingResponses is touched both here (from the sender's side) and by the parseLoop
// goroutine via getPendingResponse, so access is serialized with pendingMutex.
func (r *receiver) awaitResponse(request *Message, writer io.WriteCloser) {
	r.pendingMutex.Lock()
	defer r.pendingMutex.Unlock()

	num := request.number
	r.pendingResponses[num] = &msgStreamer{message: request, writer: writer}
	if r.maxPendingResponseNumber < num {
		r.maxPendingResponseNumber = num
	}
}
// backlog reports the number of requests still being assembled and the number of
// responses not yet fully received, read under pendingMutex.
func (r *receiver) backlog() (pendingRequest, pendingResponses int) {
	r.pendingMutex.Lock()
	pendingRequest = len(r.pendingRequests)
	pendingResponses = len(r.pendingResponses)
	r.pendingMutex.Unlock()
	return pendingRequest, pendingResponses
}
// writeFull writes all of buf to writer, retrying after short writes, and returns the
// number of bytes written. (Why isn't this in the io package already, when ReadFull is?)
//
// It stops at the first error. Bytes that Write reports alongside an error are still
// counted, as the io.Writer contract allows. If the writer makes no progress and
// reports no error, writeFull returns io.ErrShortWrite rather than looping forever.
func writeFull(buf []byte, writer io.Writer) (nWritten int, err error) {
	for len(buf) > 0 {
		var n int
		n, err = writer.Write(buf)
		nWritten += n
		if err != nil {
			return nWritten, err
		}
		if n == 0 {
			// A misbehaving writer: without this check the loop would never terminate.
			return nWritten, io.ErrShortWrite
		}
		buf = buf[n:]
	}
	return nWritten, nil
}