-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathstream2_source.go
292 lines (235 loc) · 6.11 KB
/
stream2_source.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
// SPDX-FileCopyrightText: 2021 Henry Bubert
//
// SPDX-License-Identifier: MIT
package muxrpc
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"sync"
"sync/atomic"
"github.com/karrick/bufpool"
"github.com/ssbc/go-muxrpc/v2/codec"
)
// ReadFn is the callback a ByteSourcer hands each frame to (see Reader).
// The passed reader is only valid during the call to it.
type ReadFn func(r io.Reader) error

// ByteSourcer is a source of muxrpc frames that are consumed one at a time:
// Next waits for a frame and Reader passes that frame's body to a ReadFn.
type ByteSourcer interface {
	// Next blocks until a frame can be read and reports whether one can.
	Next(context.Context) bool

	// Reader passes the body of the next frame to the given ReadFn.
	Reader(ReadFn) error

	// Cancel closes a query early, before it is drained.
	// (this is meant to send an EndErr packet back; NOTE(review): the
	// *ByteSource implementation in this file only marks itself as failed —
	// confirm where the packet is actually sent)
	Cancel(error)
}
// compile-time check that *ByteSource satisfies ByteSourcer
var _ ByteSourcer = (*ByteSource)(nil)

// ByteSource is inspired by sql.Rows but without the Scan(): it just reads
// plain []bytes, one per muxrpc packet. Packet bodies are queued in a
// frameBuffer by consume and taken out in order with Next and Reader/Bytes.
type ByteSource struct {
	// bpool supplies the backing store of buf and receives it back once the
	// stream has ended and every buffered frame was read.
	bpool bufpool.FreeList
	buf   *frameBuffer

	// mu guards failed and hdrFlag, and the closing of closed.
	mu sync.Mutex
	// closed is closed by Cancel, waking up a Next call that is waiting.
	closed chan struct{}
	// failed records why the stream ended; it is nil while the stream is open.
	failed error
	// hdrFlag is the header flag of the most recently consumed packet.
	hdrFlag codec.Flag

	// streamCtx is derived from the context handed to newByteSource; once it
	// is done, Next stops waiting for more frames.
	streamCtx context.Context
	cancel    context.CancelFunc
}
// newByteSource creates an open ByteSource whose frame buffer is backed by a
// buffer taken from pool. Its stream context is a cancelable child of ctx.
func newByteSource(ctx context.Context, pool bufpool.FreeList) *ByteSource {
	store := pool.Get()
	streamCtx, cancel := context.WithCancel(ctx)

	return &ByteSource{
		bpool:     pool,
		buf:       &frameBuffer{store: store},
		closed:    make(chan struct{}),
		streamCtx: streamCtx,
		cancel:    cancel,
	}
}
// Cancel stops reading and terminates the request; use it to close a query
// early, before it is drained.
//
// The source is marked as failed with err (io.EOF when err is nil) and the
// closed channel is closed, which wakes up a blocked Next. Only the first call
// has an effect; later calls return without doing anything.
func (bs *ByteSource) Cancel(err error) {
	bs.mu.Lock()
	defer bs.mu.Unlock()

	alreadyEnded := bs.failed != nil
	if alreadyEnded {
		return
	}

	reason := err
	if reason == nil {
		reason = io.EOF
	}
	bs.failed = reason
	close(bs.closed)
}
// Err returns the error that ended the stream, if any.
//
// It returns nil while the stream is still open, and also when it ended
// normally: io.EOF (for example from Cancel(nil)) and context.Canceled are
// not reported as errors. Any other failure, including
// context.DeadlineExceeded from an expired context, is returned as-is.
func (bs *ByteSource) Err() error {
	bs.mu.Lock()
	defer bs.mu.Unlock()

	err := bs.failed
	if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
		return nil
	}
	return err
}
// Next blocks until there are new muxrpc frames for this stream.
//
// It returns true when at least one frame is buffered and can be read with
// Reader or Bytes. It returns false when the stream has ended (Cancel, or the
// stream context being done) and no frames are left, or when ctx is done.
// ctx only bounds this call; its error is recorded if the stream had not
// already failed.
//
// Once the stream has ended and is drained, the frame buffer's pooled store is
// handed back to the pool exactly once; further calls keep returning false.
func (bs *ByteSource) Next(ctx context.Context) bool {
	bs.mu.Lock()
	// getNextFrameReader decrements frames atomically while holding only the
	// frame buffer's lock, so a plain read here would be a data race.
	buffered := bs.buf.Frames()
	if bs.failed != nil && buffered == 0 {
		// don't return buffer before stream is empty
		// TODO: what if a stream isn't fully drained?!
		fb := bs.buf
		fb.mu.Lock()
		// Swap in an empty, zero-capacity placeholder before returning the store:
		// a later Next sees Cap() == 0 and skips the Put, so the same buffer is
		// never handed to the pool twice (which would let two users share it),
		// and a stray Reader/Bytes call reads the placeholder instead of a buffer
		// that may already be in use elsewhere. consume refuses to write once
		// failed is set, so the placeholder stays empty.
		if store := fb.store; store != nil && store.Cap() > 0 {
			fb.store = new(bytes.Buffer)
			bs.bpool.Put(store)
		}
		fb.mu.Unlock()
		bs.mu.Unlock()
		return false
	}
	bs.mu.Unlock()

	if buffered > 0 {
		return true
	}

	select {
	case <-bs.streamCtx.Done():
		bs.mu.Lock()
		defer bs.mu.Unlock()
		if bs.failed == nil {
			bs.failed = bs.streamCtx.Err()
		}
		return bs.buf.Frames() > 0

	case <-ctx.Done():
		bs.mu.Lock()
		defer bs.mu.Unlock()
		if bs.failed == nil {
			bs.failed = ctx.Err()
		}
		return false

	case <-bs.closed:
		return bs.buf.Frames() > 0

	case <-bs.buf.waitForMore():
		return true
	}
}
// Reader passes fn a reader limited to the body of the next buffered frame.
//
// The reader is only valid during the call to fn: the frame buffer stays
// locked while fn runs, so no new packets can be appended to it meanwhile.
// Whatever part of the frame fn leaves unread is skipped before the following
// frame is handed out. The error from locating the next frame, or otherwise
// fn's own error, is returned.
func (bs *ByteSource) Reader(fn ReadFn) error {
	_, rd, err := bs.buf.getNextFrameReader()
	if err != nil {
		return err
	}

	bs.buf.mu.Lock()
	// fn is caller-supplied: unlock via defer so that a panic inside it cannot
	// leave the frame buffer locked and block every later copyBody forever.
	defer bs.buf.mu.Unlock()
	return fn(rd)
}
// Bytes reads the complete body of the next buffered frame and returns it as
// a new slice, together with any error from locating or reading the frame.
func (bs *ByteSource) Bytes() ([]byte, error) {
	_, frame, lookupErr := bs.buf.getNextFrameReader()
	if lookupErr != nil {
		return nil, lookupErr
	}

	fb := bs.buf
	fb.mu.Lock()
	body, readErr := ioutil.ReadAll(frame)
	fb.mu.Unlock()

	return body, readErr
}
// consume appends the body of one incoming packet (pktLen bytes read from r)
// to the frame buffer and remembers the packet's header flag. It refuses to
// accept data once the source has failed or was canceled.
func (bs *ByteSource) consume(pktLen uint32, flag codec.Flag, r io.Reader) error {
	bs.mu.Lock()
	defer bs.mu.Unlock()

	if reason := bs.failed; reason != nil {
		return fmt.Errorf("muxrpc: byte source canceled: %w", reason)
	}

	bs.hdrFlag = flag
	return bs.buf.copyBody(pktLen, r)
}
// utils

// frameBuffer queues muxrpc packet bodies ("frames") for one stream.
// Every frame is stored in one contiguous bytes.Buffer as a 4-byte
// little-endian length prefix followed by the body; copyBody appends frames
// and getNextFrameReader takes them out in order.
type frameBuffer struct {
	// mu guards store, waiting, the current-frame counters and lenBuf.
	mu    sync.Mutex
	store *bytes.Buffer

	// waiting holds one channel per waitForMore call made while no frame was
	// buffered; copyBody closes all of them when the next frame arrives.
	// Several can pile up before that happens, e.g. when a waiter gave up
	// (its channel is never removed), hence a list.
	waiting []chan<- struct{}

	// currentFrameTotal and currentFrameRead track how much of the frame last
	// handed out by getNextFrameReader has been read, so its unread remainder
	// can be skipped before the next frame is returned.
	currentFrameTotal uint32
	currentFrameRead  uint32

	// frames counts complete frames in store that were not handed out yet.
	// It is only modified atomically so Frames can be used without mu.
	frames uint32

	// lenBuf is scratch space for encoding/decoding length prefixes.
	lenBuf [4]byte
}

// Frames returns the number of buffered frames not yet handed out by
// getNextFrameReader. It does not need fb.mu.
func (fb *frameBuffer) Frames() uint32 {
	return atomic.LoadUint32(&fb.frames)
}

// copyBody appends one frame: a length prefix for pktLen followed by the
// bytes copied from rd. If rd fails, or yields a different number of bytes
// than pktLen, the partially written frame is removed again so store only
// ever holds complete frames, and an error is returned. On success the frame
// counter is incremented and all channels from waitForMore are closed.
func (fb *frameBuffer) copyBody(pktLen uint32, rd io.Reader) error {
	fb.mu.Lock()
	defer fb.mu.Unlock()

	// Remember where this frame starts: a dangling prefix or partial body left
	// behind by a failed copy would be parsed as the start of the next frame.
	start := fb.store.Len()

	binary.LittleEndian.PutUint32(fb.lenBuf[:], pktLen)
	fb.store.Write(fb.lenBuf[:]) // bytes.Buffer.Write always returns a nil error

	copied, err := io.Copy(fb.store, rd)
	if err != nil {
		fb.store.Truncate(start)
		return err
	}
	if copied != int64(pktLen) {
		fb.store.Truncate(start)
		return errors.New("frameBuffer: failed to consume whole body")
	}

	atomic.AddUint32(&fb.frames, 1)

	for _, ch := range fb.waiting {
		close(ch)
	}
	fb.waiting = nil
	return nil
}

// waitForMore returns a channel that is closed once at least one frame is
// buffered. When a frame is already there, the channel comes back closed;
// otherwise it is registered and closed by the next successful copyBody.
func (fb *frameBuffer) waitForMore() <-chan struct{} {
	fb.mu.Lock()
	defer fb.mu.Unlock()

	ch := make(chan struct{})
	if fb.Frames() > 0 {
		// Return a closed channel rather than nil: callers select on the
		// result, and receiving from a nil channel blocks forever.
		close(ch)
		return ch
	}
	fb.waiting = append(fb.waiting, ch)
	return ch
}

// getNextFrameReader advances to the next buffered frame and returns its body
// length and a reader limited to that body. Any unread part of the previously
// returned frame is discarded first. Reads through the returned reader advance
// store and update currentFrameRead, so callers hold fb.mu while reading (as
// Reader and Bytes do). An error is returned when no length prefix is buffered.
func (fb *frameBuffer) getNextFrameReader() (uint32, io.Reader, error) {
	fb.mu.Lock()
	defer fb.mu.Unlock()

	// Skip what was left unread of the previous frame and then forget about it:
	// if reading the prefix below fails, a later call must not skip those bytes
	// a second time, which would eat into the following frame.
	if fb.currentFrameTotal > fb.currentFrameRead {
		fb.store.Next(int(fb.currentFrameTotal - fb.currentFrameRead))
	}
	fb.currentFrameTotal = 0
	fb.currentFrameRead = 0

	// io.ReadFull: a plain Read may legally return fewer than 4 bytes.
	if _, err := io.ReadFull(fb.store, fb.lenBuf[:]); err != nil {
		return 0, nil, fmt.Errorf("muxrpc: didnt get length of next body (frames:%d): %w", fb.Frames(), err)
	}

	pktLen := binary.LittleEndian.Uint32(fb.lenBuf[:])
	fb.currentFrameTotal = pktLen

	rd := &countingReader{
		rd:   io.LimitReader(fb.store, int64(pktLen)),
		read: &fb.currentFrameRead,
	}

	// the frame is handed out now: fb.frames-- done atomically
	atomic.AddUint32(&fb.frames, ^uint32(0))
	return pktLen, rd, nil
}

// countingReader wraps rd and adds the number of bytes each Read returns to
// *read.
type countingReader struct {
	rd   io.Reader
	read *uint32
}

// Read reads from the wrapped reader and adds n to the counter. Bytes are
// counted even when an error is returned alongside them, because an
// io.Reader may return data together with an error such as io.EOF.
func (cr *countingReader) Read(b []byte) (int, error) {
	n, err := cr.rd.Read(b)
	if n > 0 {
		*cr.read += uint32(n)
	}
	return n, err
}