-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathstream2_sink.go
138 lines (112 loc) · 2.61 KB
/
stream2_sink.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// SPDX-FileCopyrightText: 2021 Henry Bubert
//
// SPDX-License-Identifier: MIT
package muxrpc
import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/ssbc/go-muxrpc/v2/codec"
)
// ByteSinker is the writing end of a muxrpc stream.
//
// On top of plain writes and an orderly Close it offers two extras:
//   - SetEncoding, to choose the encoding of outgoing packets;
//   - CloseWithError, to end a stream early, for example a query that
//     should stop before it is drained. (In the ByteSink implementation a
//     non-EOF error is sent back to the peer as an EndErr packet.)
type ByteSinker interface {
	io.Writer
	io.Closer

	// SetEncoding selects the encoding used for subsequent packets.
	SetEncoding(re RequestEncoding)

	// CloseWithError ends the stream, reporting the given reason.
	CloseWithError(error) error
}

// Compile-time proof that *ByteSink satisfies ByteSinker.
var _ ByteSinker = (*ByteSink)(nil)
// ByteSink exposes a WriteCloser which wraps each write into a muxrpc packet
// for that stream with the correct flags set.
type ByteSink struct {
	// w receives every packet this sink emits (data and end packets).
	w *codec.Writer

	// closedMu guards closed and pkt; SetEncoding, Write and CloseWithError
	// all hold it for their whole duration.
	closedMu sync.Mutex
	// closed is non-nil once the sink must refuse further writes: it holds
	// the error close reason, a failed packet write, the stream context's
	// error, or a close-timeout error. Write and CloseWithError return it
	// as-is once set.
	closed error

	// streamCtx is checked at the start of every Write; once it is done the
	// sink records its error in closed.
	streamCtx context.Context

	// pkt is the template for outgoing packets: Req and Flag are reused for
	// every data packet and for the end packet, Body is replaced per Write.
	// Write refuses to send while Req is 0; NOTE(review): Req is presumably
	// assigned by the code that opens the stream — not visible here.
	pkt codec.Packet
}
// newByteSink builds a ByteSink that emits packets through w and stops
// accepting writes once ctx is done. The packet template starts empty (no
// request ID, no flags); NOTE(review): callers presumably set those before
// the first Write — confirm at the call sites.
func newByteSink(ctx context.Context, w *codec.Writer) *ByteSink {
	sink := new(ByteSink)
	sink.w = w
	sink.streamCtx = ctx
	// sink.pkt is left at its zero value: an empty codec.Packet.
	return sink
}
// SetEncoding selects the body encoding flag for every packet the sink
// writes from now on.
//
// Both encoding bits (FlagJSON and FlagString) are cleared before the new
// flag is applied. Switching between encodings therefore never leaves a
// stale bit set, for example string+json after a string→json switch, or a
// leftover string bit after switching to binary. Other flags, such as
// FlagStream, are left untouched.
//
// It panics if re cannot be converted to a codec flag. The ByteSinker
// interface gives this method no way to return an error, so an unknown
// encoding is treated as a programming error.
func (bs *ByteSink) SetEncoding(re RequestEncoding) {
	encFlag, err := re.asCodecFlag()
	if err != nil {
		panic(err)
	}

	bs.closedMu.Lock()
	defer bs.closedMu.Unlock()

	// Reset the encoding bits first so exactly one encoding (or none, for
	// binary) ends up set.
	bs.pkt.Flag = bs.pkt.Flag.Clear(codec.FlagJSON)
	bs.pkt.Flag = bs.pkt.Flag.Clear(codec.FlagString)
	bs.pkt.Flag = bs.pkt.Flag.Set(encFlag)
}
// Write sends b to the peer as the body of one muxrpc packet, using the
// sink's current request ID and flags.
//
// Nothing is sent in any of these cases:
//   - the sink is already closed;
//   - its stream context is done, which also marks the sink closed;
//   - no request ID has been assigned yet.
//
// A failed packet write marks the sink closed with that error, so later
// Writes and closes report it.
//
// In line with the io.Writer contract, the count is 0 whenever an error is
// returned (never negative), and b is not retained after Write returns.
func (bs *ByteSink) Write(b []byte) (int, error) {
	bs.closedMu.Lock()
	defer bs.closedMu.Unlock()

	if bs.closed != nil {
		return 0, bs.closed
	}

	// The stream may have been cancelled since the last write.
	if err := bs.streamCtx.Err(); err != nil {
		bs.closed = err
		return 0, err
	}

	if bs.pkt.Req == 0 {
		return 0, fmt.Errorf("req ID not set (Flag: %s)", bs.pkt.Flag)
	}

	bs.pkt.Body = b
	err := bs.w.WritePacket(bs.pkt)
	// Drop our reference so the caller's buffer is not retained.
	bs.pkt.Body = nil
	if err != nil {
		bs.closed = err
		return 0, err
	}
	return len(b), nil
}
// CloseWithError ends the stream and marks the sink closed, so no further
// packets are written through it.
//
// The end packet depends on err:
//   - nil or io.EOF (including wrapped EOF): an orderly end-okay packet;
//   - anything else: an end-error packet carrying err.
//
// If building the end-error packet fails, that error is returned and the
// sink stays open.
//
// Repeated calls:
//   - after an orderly close, another close is a no-op that returns nil;
//   - if the sink closed for any other reason (an error close, a failed
//     write, a cancelled stream context, a timeout), that reason is
//     returned and nothing is sent.
//
// Writing the end packet is bounded by a 10 second timeout so a stuck
// connection cannot block the caller forever. On timeout or write failure
// the sink records that error and returns it.
func (bs *ByteSink) CloseWithError(err error) error {
	bs.closedMu.Lock()
	defer bs.closedMu.Unlock()

	if bs.closed != nil {
		if bs.closed == io.EOF {
			// Already ended normally; nothing left to do.
			return nil
		}
		return bs.closed
	}

	isStream := bs.pkt.Flag.Get(codec.FlagStream)

	var (
		closePkt codec.Packet
		reason   error
	)
	if err == nil || errors.Is(err, io.EOF) {
		closePkt = newEndOkayPacket(bs.pkt.Req, isStream)
		// Record the orderly close with the plain io.EOF sentinel so a
		// repeated Close can recognise it.
		reason = io.EOF
	} else {
		var epkt error
		closePkt, epkt = newEndErrPacket(bs.pkt.Req, isStream, err)
		if epkt != nil {
			return fmt.Errorf("close bytesink: error building error packet for %s: %w", err, epkt)
		}
		reason = err
	}
	// The lock is held for the rest of this call, so marking the sink
	// closed now keeps concurrent Writes from running after the end packet.
	bs.closed = reason

	// Tolerate a stalled writer. errc is buffered so the writer goroutine
	// can always deliver its result and exit, even after we stop waiting.
	const closeTimeout = 10 * time.Second
	errc := make(chan error, 1)
	go func() {
		errc <- bs.w.WritePacket(closePkt)
	}()

	timer := time.NewTimer(closeTimeout)
	defer timer.Stop()

	select {
	case werr := <-errc:
		if werr != nil {
			bs.closed = werr
			return werr
		}
		return nil
	case <-timer.C:
		bs.closed = errors.New("muxrpc: close timeout exceeded")
		return bs.closed
	}
}
// Close ends the stream normally. It is shorthand for
// CloseWithError(io.EOF), the orderly (end-okay) way of closing the sink.
func (bs *ByteSink) Close() error {
	orderly := io.EOF
	return bs.CloseWithError(orderly)
}