diff --git a/go/codec.go b/go/codec.go index 68bdc67..685f0a0 100644 --- a/go/codec.go +++ b/go/codec.go @@ -88,7 +88,8 @@ type virtualCodec struct { physicalConn *link.Session connID uint32 recvChan chan []byte - closeOnce sync.Once + closeMutex sync.Mutex + closed bool lastActive *int64 format MsgFormat } @@ -104,6 +105,24 @@ func (p *protocol) newVirtualCodec(physicalConn *link.Session, connID uint32, re } } +func (c *virtualCodec) forward(buf []byte) { + c.closeMutex.Lock() + if c.closed { + c.closeMutex.Unlock() + c.free(buf) + return + } + select { + case c.recvChan <- buf: + c.closeMutex.Unlock() + return + default: + c.closeMutex.Unlock() + c.Close() + c.free(buf) + } +} + func (c *virtualCodec) Receive() (interface{}, error) { buf, ok := <-c.recvChan if !ok { @@ -137,10 +156,14 @@ func (c *virtualCodec) Send(msg interface{}) error { } func (c *virtualCodec) Close() error { - c.closeOnce.Do(func() { + c.closeMutex.Lock() + if !c.closed { + c.closed = true close(c.recvChan) c.send(c.physicalConn, c.encodeCloseCmd(c.connID)) - }) + } + c.closeMutex.Unlock() + for buf := range c.recvChan { c.free(buf) } diff --git a/go/codec_test.go b/go/codec_test.go index c41974c..f47f693 100644 --- a/go/codec_test.go +++ b/go/codec_test.go @@ -171,3 +171,23 @@ func Test_BadVirtualCodec(t *testing.T) { vcodec.recvChan <- bigMsg vcodec.Close() } + +func Test_VirtualCodecReceivcBlock(t *testing.T) { + conn, err := net.Dial("tcp", TestAddr) + utest.IsNilNow(t, err) + defer conn.Close() + + codec := TestProto.newCodec(0, conn, 1024) + pconn := link.NewSession(codec, 1000) + + var lastActive int64 + recvChanSize := 2 + vcodec := TestProto.newVirtualCodec(pconn, 123, recvChanSize, &lastActive, &TestMsgFormat{}) + buf := make([]byte, 100) + for i := 0; i <= recvChanSize; i++ { + vcodec.forward(buf) + } + vcodec.closeMutex.Lock() + defer vcodec.closeMutex.Unlock() + utest.Assert(t, vcodec.closed) +} diff --git a/go/endpoint.go b/go/endpoint.go index 065004a..10a9c8a 100755 --- a/go/endpoint.go +++ b/go/endpoint.go @@ -255,15 +255,11 @@ func (p *EndPoint) loop() { vconn := p.virtualConns.Get(connID) if vconn != nil { - select { - case vconn.Codec().(*virtualCodec).recvChan <- buf: - continue - default: - vconn.Close() - } + vconn.Codec().(*virtualCodec).forward(buf) + } else { + p.free(buf) + p.send(p.session, p.encodeCloseCmd(connID)) } - p.free(buf) - p.send(p.session, p.encodeCloseCmd(connID)) } } diff --git a/go/gateway_test.go b/go/gateway_test.go index de5bd6d..cd7c83b 100755 --- a/go/gateway_test.go +++ b/go/gateway_test.go @@ -353,3 +353,48 @@ func Test_BadEndPoint(t *testing.T) { _, err = DialServer("tcp", lsn3.Addr().String(), TestEndPointCfg) utest.NotNilNow(t, err) } + +func Test_VConnSimultaneouslyCloseAndReceive(t *testing.T) { + lsn1, err := net.Listen("tcp", "127.0.0.1:0") + utest.IsNilNow(t, err) + defer lsn1.Close() + + lsn2, err := net.Listen("tcp", "127.0.0.1:0") + utest.IsNilNow(t, err) + defer lsn2.Close() + + gw := NewGateway(TestPool, TestMaxPacket) + + go gw.ServeClients(lsn1, TestGatewayCfg) + go gw.ServeServers(lsn2, TestGatewayCfg) + + time.Sleep(time.Second) + + server, err := DialServer("tcp", lsn2.Addr().String(), TestEndPointCfg) + utest.IsNilNow(t, err) + time.Sleep(time.Second) + go func() { + for { + vconn, err := server.Accept() + if err != nil { + return + } + runtime.Gosched() + vconn.Close() + runtime.Gosched() + } + }() + time.Sleep(time.Second) + payload := make([]byte, 10) + for i := 0; i < 10000; i++ { + client, err := DialClient("tcp", lsn1.Addr().String(), TestEndPointCfg) + utest.IsNilNow(t, err) + vconn, err := client.Dial(123) + utest.IsNilNow(t, err) + runtime.Gosched() + vconn.Send(payload) + runtime.Gosched() + client.Close() + } + gw.Stop() +}