From 746b2d3c160e0a5ba2093e0fb958db12f9d5dddc Mon Sep 17 00:00:00 2001 From: Jeroen Rinzema Date: Fri, 1 Sep 2023 11:02:53 +0200 Subject: [PATCH] fix: track command sequences --- command.go | 21 ++++++++++++++++----- wire.go | 1 + 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/command.go b/command.go index d00c71e..d40eaad 100644 --- a/command.go +++ b/command.go @@ -69,10 +69,6 @@ func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *b return err } - // NOTE: we increase the wait group by one in order to make sure that idle - // connections are not blocking a close. - srv.wg.Add(1) - srv.logger.Debug("incoming command", slog.Int("length", length), slog.String("type", string(t))) err = srv.handleCommand(ctx, conn, t, reader, writer) if errors.Is(err, io.EOF) { @@ -108,19 +104,34 @@ func (srv *Server) handleMessageSizeExceeded(reader *buffer.Reader, writer *buff return ErrorCode(writer, exceeded) } +func (srv *Server) startSequence() { + srv.mu.Lock() + srv.wg.Add(1) + srv.mu.Unlock() +} + +func (srv *Server) endSequence() { + srv.mu.Lock() + srv.wg.Done() + srv.mu.Unlock() +} + // handleCommand handles the given client message. A client message includes a // message type and reader buffer containing the actual message. The type // indecates a action executed by the client. // https://www.postgresql.org/docs/14/protocol-message-formats.html func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) (err error) { - defer srv.wg.Done() ctx, cancel := context.WithCancel(ctx) defer cancel() switch t { case types.ClientSimpleQuery: + srv.startSequence() + defer srv.endSequence() return srv.handleSimpleQuery(ctx, reader, writer) case types.ClientExecute: + srv.startSequence() + defer srv.endSequence() return srv.handleExecute(ctx, reader, writer) case types.ClientParse: return srv.handleParse(ctx, reader, writer) diff --git a/wire.go b/wire.go index 632bed0..ca472f4 100644 --- a/wire.go +++ b/wire.go @@ -53,6 +53,7 @@ func NewServer(parse ParseFn, options ...OptionFn) (*Server, error) { // Server contains options for listening to an address. type Server struct { wg sync.WaitGroup + mu sync.RWMutex logger *slog.Logger types *pgtype.ConnInfo Auth AuthStrategy