Skip to content

Commit

Permalink
fix: track command sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
jeroenrinzema committed Sep 1, 2023
1 parent 40b518e commit 746b2d3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
21 changes: 16 additions & 5 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *b
return err
}

// NOTE: we increase the wait group by one in order to make sure that idle
// connections are not blocking a close.
srv.wg.Add(1)

srv.logger.Debug("incoming command", slog.Int("length", length), slog.String("type", string(t)))
err = srv.handleCommand(ctx, conn, t, reader, writer)
if errors.Is(err, io.EOF) {
Expand Down Expand Up @@ -108,19 +104,34 @@ func (srv *Server) handleMessageSizeExceeded(reader *buffer.Reader, writer *buff
return ErrorCode(writer, exceeded)
}

func (srv *Server) startSequence() {
srv.mu.Lock()
srv.wg.Add(1)
srv.mu.Unlock()
}

func (srv *Server) endSequence() {
srv.mu.Lock()
srv.wg.Done()
srv.mu.Unlock()
}

// handleCommand handles the given client message. A client message includes a
// message type and reader buffer containing the actual message. The type
// indecates a action executed by the client.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) (err error) {
defer srv.wg.Done()
ctx, cancel := context.WithCancel(ctx)
defer cancel()

switch t {
case types.ClientSimpleQuery:
srv.startSequence()
defer srv.endSequence()
return srv.handleSimpleQuery(ctx, reader, writer)
case types.ClientExecute:
srv.startSequence()
defer srv.endSequence()
return srv.handleExecute(ctx, reader, writer)
case types.ClientParse:
return srv.handleParse(ctx, reader, writer)
Expand Down
1 change: 1 addition & 0 deletions wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func NewServer(parse ParseFn, options ...OptionFn) (*Server, error) {
// Server contains options for listening to an address.
type Server struct {
wg sync.WaitGroup
mu sync.RWMutex
logger *slog.Logger
types *pgtype.ConnInfo
Auth AuthStrategy
Expand Down

0 comments on commit 746b2d3

Please sign in to comment.