diff --git a/consumer_group.go b/consumer_group.go index 01cdb669c..91b6e584e 100644 --- a/consumer_group.go +++ b/consumer_group.go @@ -242,6 +242,8 @@ func (c *consumerGroup) ResumeAll() { func (c *consumerGroup) retryNewSession(ctx context.Context, topics []string, handler ConsumerGroupHandler, retries int, refreshCoordinator bool) (*consumerGroupSession, error) { select { + case <-ctx.Done(): + return nil, ctx.Err() case <-c.closed: return nil, ErrClosedConsumerGroup case <-time.After(c.config.Consumer.Group.Rebalance.Retry.Backoff): @@ -261,6 +263,9 @@ func (c *consumerGroup) retryNewSession(ctx context.Context, topics []string, ha } func (c *consumerGroup) newSession(ctx context.Context, topics []string, handler ConsumerGroupHandler, retries int) (*consumerGroupSession, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } coordinator, err := c.client.Coordinator(c.groupID) if err != nil { if retries <= 0 { diff --git a/consumer_group_test.go b/consumer_group_test.go index 5bfdcc8f3..30d615e98 100644 --- a/consumer_group_test.go +++ b/consumer_group_test.go @@ -232,3 +232,15 @@ func TestConsumerGroupSessionDoesNotRetryForever(t *testing.T) { wg.Wait() } + +func TestConsumerShouldNotRetrySessionIfContextCancelled(t *testing.T) { + c := &consumerGroup{ + config: NewTestConfig(), + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := c.newSession(ctx, nil, nil, 1024) + assert.Equal(t, context.Canceled, err) + _, err = c.retryNewSession(ctx, nil, nil, 1024, true) + assert.Equal(t, context.Canceled, err) +}