diff --git a/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala b/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala index cdcf5389800fb..322dccd0dcb37 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala @@ -211,8 +211,7 @@ class RPCProducerIdManager(brokerId: Int, } - // Visible for testing - private[transaction] def maybeRequestNextBlock(): Unit = { + private def maybeRequestNextBlock(): Unit = { val retryTimestamp = backoffDeadlineMs.get() if (retryTimestamp == NoRetry || time.milliseconds() >= retryTimestamp) { // Send a request only if we reached the retry deadline, or if no deadline was set. diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala index 17e35ffde65ac..eef0b31e415a3 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala @@ -35,10 +35,10 @@ import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers.{any, anyString} import org.mockito.Mockito.{mock, when} -import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, Executors, TimeUnit} import scala.collection.mutable -import scala.util.{Failure, Success} +import scala.util.{Failure, Success, Try} class ProducerIdManagerTest { @@ -50,10 +50,9 @@ class ProducerIdManagerTest { val brokerId: Int, var idStart: Long, val idLen: Int, - var error: Errors = Errors.NONE, + val errorQueue: ConcurrentLinkedQueue[Errors] = new ConcurrentLinkedQueue[Errors](), val isErroneousBlock: Boolean = false, - val time: Time = Time.SYSTEM, - var remainingRetries: Int = 1 + val time: Time = Time.SYSTEM ) extends RPCProducerIdManager(brokerId, time, () => 1, brokerToController) { private val brokerToControllerRequestExecutor = Executors.newSingleThreadExecutor() @@ -62,7 +61,8 @@ class ProducerIdManagerTest { override private[transaction] def sendRequest(): Unit = { brokerToControllerRequestExecutor.submit(() => { - if (error == Errors.NONE) { + val error = errorQueue.poll() + if (error == null || error == Errors.NONE) { handleAllocateProducerIdsResponse(new AllocateProducerIdsResponse( new AllocateProducerIdsResponseData().setProducerIdStart(idStart).setProducerIdLen(idLen))) if (!isErroneousBlock) { @@ -79,17 +79,6 @@ class ProducerIdManagerTest { super.handleAllocateProducerIdsResponse(response) capturedFailure.set(nextProducerIdBlock.get == null) } - - override private[transaction] def maybeRequestNextBlock(): Unit = { - if (error == Errors.NONE && !isErroneousBlock) { - super.maybeRequestNextBlock() - } else { - if (remainingRetries > 0) { - super.maybeRequestNextBlock() - remainingRetries -= 1 - } - } - } } @Test @@ -190,15 +179,12 @@ class ProducerIdManagerTest { @EnumSource(value = classOf[Errors], names = Array("UNKNOWN_SERVER_ERROR", "INVALID_REQUEST")) def testUnrecoverableErrors(error: Errors): Unit = { val time = new MockTime() - val manager = new MockProducerIdManager(0, 0, 1, time = time) + val manager = new MockProducerIdManager(0, 0, 1, errorQueue = queue(Errors.NONE, error), time = time) verifyNewBlockAndProducerId(manager, new ProducerIdsBlock(0, 0, 1), 0) - manager.error = error - time.sleep(RetryBackoffMs) verifyFailure(manager) - manager.error = Errors.NONE time.sleep(RetryBackoffMs) verifyNewBlockAndProducerId(manager, new ProducerIdsBlock(0, 1, 1), 1) } @@ -219,19 +205,24 @@ class ProducerIdManagerTest { def testRetryBackoff(): Unit = { val time = new MockTime() val manager = new MockProducerIdManager(0, 0, 1, - error = Errors.UNKNOWN_SERVER_ERROR, time = time, remainingRetries = 2) + errorQueue = queue(Errors.UNKNOWN_SERVER_ERROR), time = time) verifyFailure(manager) - manager.error = Errors.NONE // We should only get a new block once retry backoff ms has passed. - assertEquals(classOf[CoordinatorLoadInProgressException], manager.generateProducerId().failed.get.getClass) + assertCoordinatorLoadInProgressExceptionFailure(manager.generateProducerId()) time.sleep(RetryBackoffMs) verifyNewBlockAndProducerId(manager, new ProducerIdsBlock(0, 0, 1), 0) } + private def queue(errors: Errors*): ConcurrentLinkedQueue[Errors] = { + val queue = new ConcurrentLinkedQueue[Errors]() + errors.foreach(queue.add) + queue + } + private def verifyFailure(manager: MockProducerIdManager): Unit = { - assertEquals(classOf[CoordinatorLoadInProgressException], manager.generateProducerId().failed.get.getClass) + assertCoordinatorLoadInProgressExceptionFailure(manager.generateProducerId()) TestUtils.waitUntilTrue(() => { manager synchronized { manager.capturedFailure.get @@ -244,12 +235,17 @@ class ProducerIdManagerTest { expectedBlock: ProducerIdsBlock, expectedPid: Long): Unit = { - assertEquals(classOf[CoordinatorLoadInProgressException], manager.generateProducerId().failed.get.getClass) + assertCoordinatorLoadInProgressExceptionFailure(manager.generateProducerId()) TestUtils.waitUntilTrue(() => { val nextBlock = manager.nextProducerIdBlock.get nextBlock != null && nextBlock.equals(expectedBlock) }, "failed to generate block") assertEquals(expectedPid, manager.generateProducerId().get) } + + private def assertCoordinatorLoadInProgressExceptionFailure(generatedProducerId: Try[Long]): Unit = { + assertTrue(generatedProducerId.isFailure, () => s"expected failure but got producerId: ${generatedProducerId.get}") + assertEquals(classOf[CoordinatorLoadInProgressException], generatedProducerId.failed.get.getClass) + } }