From b100f1efac77bf795683ab5e68ecf87845372089 Mon Sep 17 00:00:00 2001 From: Dimitar Dimitrov <30328539+dimitarndimitrov@users.noreply.github.com> Date: Tue, 20 Jun 2023 16:50:46 +0200 Subject: [PATCH] KAFKA-15087 Move/rewrite InterBrokerSendThread to server-commons (#13856) The Java rewrite is kept relatively close to the Scala original to minimize potential newly introduced bugs and to make reviewing simpler. The following details might be of note: - The `Logging` trait moved to InterBrokerSendThread with the rewrite of ShutdownableThread has been similarly moved to any subclasses that currently use it. InterBrokerSendThread's own logging has been made to use ShutdownableThread's logger which mimics the prefix/log identifier that the trait provided. - The case RequestAndCompletionHandler class has been made a separate POJO class and the internal-use UnsentRequests class has been kept as a static nested class. - The relatively commonly used but internal (not part of the public API) clients classes that InterBrokerSendThread relies on have been allowlisted in the server-common import control. - The accompanying test class has also been moved and rewritten with one new test added and most of the pre-existing tests made stricter. Reviewers: David Jacot --- checkstyle/import-control-server-common.xml | 9 + .../kafka/common/InterBrokerSendThread.scala | 218 ------------ .../TransactionMarkerChannelManager.scala | 13 +- .../kafka/raft/KafkaNetworkChannel.scala | 21 +- .../server/AddPartitionsToTxnManager.scala | 19 +- .../BrokerToControllerChannelManager.scala | 19 +- .../common/InterBrokerSendThreadTest.scala | 256 -------------- ...ransactionCoordinatorConcurrencyTest.scala | 2 +- .../TransactionMarkerChannelManagerTest.scala | 17 +- .../AddPartitionsToTxnManagerTest.scala | 12 +- .../server/util/InterBrokerSendThread.java | 253 ++++++++++++++ .../util/RequestAndCompletionHandler.java | 51 +++ .../kafka/server/util/ShutdownableThread.java | 2 +- .../util/InterBrokerSendThreadTest.java | 326 ++++++++++++++++++ 14 files changed, 695 insertions(+), 523 deletions(-) delete mode 100644 core/src/main/scala/kafka/common/InterBrokerSendThread.scala delete mode 100644 core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala create mode 100644 server-common/src/main/java/org/apache/kafka/server/util/InterBrokerSendThread.java create mode 100644 server-common/src/main/java/org/apache/kafka/server/util/RequestAndCompletionHandler.java create mode 100644 server-common/src/test/java/org/apache/kafka/server/util/InterBrokerSendThreadTest.java diff --git a/checkstyle/import-control-server-common.xml b/checkstyle/import-control-server-common.xml index f238bc9b8d9..350f2820968 100644 --- a/checkstyle/import-control-server-common.xml +++ b/checkstyle/import-control-server-common.xml @@ -81,6 +81,15 @@ + + + + + + + + + diff --git a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala deleted file mode 100644 index 4ac37ce6346..00000000000 --- a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package kafka.common - -import kafka.utils.Logging - -import java.util.Map.Entry -import java.util.{ArrayDeque, ArrayList, Collection, Collections, HashMap, Iterator} -import org.apache.kafka.clients.{ClientRequest, ClientResponse, KafkaClient, RequestCompletionHandler} -import org.apache.kafka.common.Node -import org.apache.kafka.common.errors.{AuthenticationException, DisconnectException} -import org.apache.kafka.common.internals.FatalExitError -import org.apache.kafka.common.requests.AbstractRequest -import org.apache.kafka.common.utils.Time -import org.apache.kafka.server.util.ShutdownableThread - -import scala.jdk.CollectionConverters._ - -/** - * Class for inter-broker send thread that utilize a non-blocking network client. - */ -abstract class InterBrokerSendThread( - name: String, - @volatile var networkClient: KafkaClient, - requestTimeoutMs: Int, - time: Time, - isInterruptible: Boolean = true -) extends ShutdownableThread(name, isInterruptible) with Logging { - - this.logIdent = logPrefix - - private val unsentRequests = new UnsentRequests - - def generateRequests(): Iterable[RequestAndCompletionHandler] - - def hasUnsentRequests: Boolean = unsentRequests.iterator().hasNext - - override def shutdown(): Unit = { - initiateShutdown() - networkClient.initiateClose() - awaitShutdown() - networkClient.close() - } - - private def drainGeneratedRequests(): Unit = { - generateRequests().foreach { request => - unsentRequests.put(request.destination, - networkClient.newClientRequest( - request.destination.idString, - request.request, - request.creationTimeMs, - true, - requestTimeoutMs, - request.handler - )) - } - } - - protected def pollOnce(maxTimeoutMs: Long): Unit = { - try { - drainGeneratedRequests() - var now = time.milliseconds() - val timeout = sendRequests(now, maxTimeoutMs) - networkClient.poll(timeout, now) - now = time.milliseconds() - checkDisconnects(now) - failExpiredRequests(now) - unsentRequests.clean() - } catch { - case _: DisconnectException if !networkClient.active() => - // DisconnectException is expected when NetworkClient#initiateClose is called - case e: FatalExitError => throw e - case t: Throwable => - error(s"unhandled exception caught in InterBrokerSendThread", t) - // rethrow any unhandled exceptions as FatalExitError so the JVM will be terminated - // as we will be in an unknown state with potentially some requests dropped and not - // being able to make progress. Known and expected Errors should have been appropriately - // dealt with already. - throw new FatalExitError() - } - } - - override def doWork(): Unit = { - pollOnce(Long.MaxValue) - } - - private def sendRequests(now: Long, maxTimeoutMs: Long): Long = { - var pollTimeout = maxTimeoutMs - for (node <- unsentRequests.nodes.asScala) { - val requestIterator = unsentRequests.requestIterator(node) - while (requestIterator.hasNext) { - val request = requestIterator.next - if (networkClient.ready(node, now)) { - networkClient.send(request, now) - requestIterator.remove() - } else - pollTimeout = Math.min(pollTimeout, networkClient.connectionDelay(node, now)) - } - } - pollTimeout - } - - private def checkDisconnects(now: Long): Unit = { - // any disconnects affecting requests that have already been transmitted will be handled - // by NetworkClient, so we just need to check whether connections for any of the unsent - // requests have been disconnected; if they have, then we complete the corresponding future - // and set the disconnect flag in the ClientResponse - val iterator = unsentRequests.iterator() - while (iterator.hasNext) { - val entry = iterator.next - val (node, requests) = (entry.getKey, entry.getValue) - if (!requests.isEmpty && networkClient.connectionFailed(node)) { - iterator.remove() - for (request <- requests.asScala) { - val authenticationException = networkClient.authenticationException(node) - if (authenticationException != null) - error(s"Failed to send the following request due to authentication error: $request") - completeWithDisconnect(request, now, authenticationException) - } - } - } - } - - private def failExpiredRequests(now: Long): Unit = { - // clear all expired unsent requests - val timedOutRequests = unsentRequests.removeAllTimedOut(now) - for (request <- timedOutRequests.asScala) { - debug(s"Failed to send the following request after ${request.requestTimeoutMs} ms: $request") - completeWithDisconnect(request, now, null) - } - } - - def completeWithDisconnect(request: ClientRequest, - now: Long, - authenticationException: AuthenticationException): Unit = { - val handler = request.callback - handler.onComplete(new ClientResponse(request.makeHeader(request.requestBuilder().latestAllowedVersion()), - handler, request.destination, now /* createdTimeMs */ , now /* receivedTimeMs */ , true /* disconnected */ , - null /* versionMismatch */ , authenticationException, null)) - } - - def wakeup(): Unit = networkClient.wakeup() -} - -case class RequestAndCompletionHandler( - creationTimeMs: Long, - destination: Node, - request: AbstractRequest.Builder[_ <: AbstractRequest], - handler: RequestCompletionHandler -) - -private class UnsentRequests { - private val unsent = new HashMap[Node, ArrayDeque[ClientRequest]] - - def put(node: Node, request: ClientRequest): Unit = { - var requests = unsent.get(node) - if (requests == null) { - requests = new ArrayDeque[ClientRequest] - unsent.put(node, requests) - } - requests.add(request) - } - - def removeAllTimedOut(now: Long): Collection[ClientRequest] = { - val expiredRequests = new ArrayList[ClientRequest] - for (requests <- unsent.values.asScala) { - val requestIterator = requests.iterator - var foundExpiredRequest = false - while (requestIterator.hasNext && !foundExpiredRequest) { - val request = requestIterator.next - val elapsedMs = Math.max(0, now - request.createdTimeMs) - if (elapsedMs > request.requestTimeoutMs) { - expiredRequests.add(request) - requestIterator.remove() - foundExpiredRequest = true - } - } - } - expiredRequests - } - - def clean(): Unit = { - val iterator = unsent.values.iterator - while (iterator.hasNext) { - val requests = iterator.next - if (requests.isEmpty) - iterator.remove() - } - } - - def iterator(): Iterator[Entry[Node, ArrayDeque[ClientRequest]]] = { - unsent.entrySet().iterator() - } - - def requestIterator(node: Node): Iterator[ClientRequest] = { - val requests = unsent.get(node) - if (requests == null) - Collections.emptyIterator[ClientRequest] - else - requests.iterator - } - - def nodes: java.util.Set[Node] = unsent.keySet -} diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala index 049aae97c7e..647db68fa47 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala @@ -19,8 +19,6 @@ package kafka.coordinator.transaction import java.util import java.util.concurrent.{BlockingQueue, ConcurrentHashMap, LinkedBlockingQueue} - -import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} import kafka.server.{KafkaConfig, MetadataCache, RequestLocal} import kafka.utils.Implicits._ import kafka.utils.{CoreUtils, Logging} @@ -35,6 +33,7 @@ import org.apache.kafka.common.utils.{LogContext, Time} import org.apache.kafka.common.{Node, Reconfigurable, TopicPartition} import org.apache.kafka.server.common.MetadataVersion.IBP_2_8_IV0 import org.apache.kafka.server.metrics.KafkaMetricsGroup +import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler} import scala.collection.{concurrent, immutable} import scala.jdk.CollectionConverters._ @@ -183,7 +182,7 @@ class TransactionMarkerChannelManager( } def retryLogAppends(): Unit = { - val txnLogAppendRetries: java.util.List[PendingCompleteTxn] = new util.ArrayList[PendingCompleteTxn]() + val txnLogAppendRetries: util.List[PendingCompleteTxn] = new util.ArrayList[PendingCompleteTxn]() txnLogAppendRetryQueue.drainTo(txnLogAppendRetries) txnLogAppendRetries.forEach { txnLogAppend => debug(s"Retry appending $txnLogAppend transaction log") @@ -191,9 +190,9 @@ class TransactionMarkerChannelManager( } } - override def generateRequests(): Iterable[RequestAndCompletionHandler] = { + override def generateRequests(): util.Collection[RequestAndCompletionHandler] = { retryLogAppends() - val txnIdAndMarkerEntries: java.util.List[TxnIdAndMarkerEntry] = new util.ArrayList[TxnIdAndMarkerEntry]() + val txnIdAndMarkerEntries: util.List[TxnIdAndMarkerEntry] = new util.ArrayList[TxnIdAndMarkerEntry]() markersQueueForUnknownBroker.forEachTxnTopicPartition { case (_, queue) => queue.drainTo(txnIdAndMarkerEntries) } @@ -221,13 +220,13 @@ class TransactionMarkerChannelManager( val requestCompletionHandler = new TransactionMarkerRequestCompletionHandler(node.id, txnStateManager, this, entries) val request = new WriteTxnMarkersRequest.Builder(writeTxnMarkersRequestVersion, markersToSend) - RequestAndCompletionHandler( + new RequestAndCompletionHandler( currentTimeMs, node, request, requestCompletionHandler ) - } + }.asJavaCollection } private def writeTxnCompletion(pendingCompleteTxn: PendingCompleteTxn): Unit = { diff --git a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala index 0d86e257932..27a489b72a0 100644 --- a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala +++ b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala @@ -16,7 +16,6 @@ */ package kafka.raft -import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} import kafka.utils.Logging import org.apache.kafka.clients.{ClientResponse, KafkaClient} import org.apache.kafka.common.Node @@ -26,7 +25,9 @@ import org.apache.kafka.common.requests._ import org.apache.kafka.common.utils.Time import org.apache.kafka.raft.RaftConfig.InetAddressSpec import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, RaftUtil} +import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler} +import java.util import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable @@ -67,17 +68,17 @@ private[raft] class RaftSendThread( ) { private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]() - def generateRequests(): Iterable[RequestAndCompletionHandler] = { - val buffer = mutable.Buffer[RequestAndCompletionHandler]() + def generateRequests(): util.Collection[RequestAndCompletionHandler] = { + val list = new util.ArrayList[RequestAndCompletionHandler]() while (true) { val request = queue.poll() if (request == null) { - return buffer + return list } else { - buffer += request + list.add(request) } } - buffer + list } def sendRequest(request: RequestAndCompletionHandler): Unit = { @@ -142,11 +143,11 @@ class KafkaNetworkChannel( endpoints.get(request.destinationId) match { case Some(node) => - requestThread.sendRequest(RequestAndCompletionHandler( + requestThread.sendRequest(new RequestAndCompletionHandler( request.createdTimeMs, - destination = node, - request = buildRequest(request.data), - handler = onComplete + node, + buildRequest(request.data), + onComplete )) case None => diff --git a/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala b/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala index 4fd5a29b4a8..cbf981a76dd 100644 --- a/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala +++ b/core/src/main/scala/kafka/server/AddPartitionsToTxnManager.scala @@ -17,14 +17,16 @@ package kafka.server -import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} +import kafka.utils.Logging import org.apache.kafka.clients.{ClientResponse, NetworkClient, RequestCompletionHandler} import org.apache.kafka.common.{Node, TopicPartition} import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTransaction, AddPartitionsToTxnTransactionCollection} import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.requests.{AddPartitionsToTxnRequest, AddPartitionsToTxnResponse} import org.apache.kafka.common.utils.Time +import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler} +import java.util import scala.collection.mutable object AddPartitionsToTxnManager { @@ -37,7 +39,10 @@ class TransactionDataAndCallbacks(val transactionData: AddPartitionsToTxnTransac class AddPartitionsToTxnManager(config: KafkaConfig, client: NetworkClient, time: Time) - extends InterBrokerSendThread("AddPartitionsToTxnSenderThread-" + config.brokerId, client, config.requestTimeoutMs, time) { + extends InterBrokerSendThread("AddPartitionsToTxnSenderThread-" + config.brokerId, client, config.requestTimeoutMs, time) + with Logging { + + this.logIdent = logPrefix private val inflightNodes = mutable.HashSet[Node]() private val nodesToTransactions = mutable.Map[Node, TransactionDataAndCallbacks]() @@ -157,20 +162,20 @@ class AddPartitionsToTxnManager(config: KafkaConfig, client: NetworkClient, time } } - override def generateRequests(): Iterable[RequestAndCompletionHandler] = { + override def generateRequests(): util.Collection[RequestAndCompletionHandler] = { // build and add requests to queue - val buffer = mutable.Buffer[RequestAndCompletionHandler]() + val list = new util.ArrayList[RequestAndCompletionHandler]() val currentTimeMs = time.milliseconds() val removedNodes = mutable.Set[Node]() nodesToTransactions.synchronized { nodesToTransactions.foreach { case (node, transactionDataAndCallbacks) => if (!inflightNodes.contains(node)) { - buffer += RequestAndCompletionHandler( + list.add(new RequestAndCompletionHandler( currentTimeMs, node, AddPartitionsToTxnRequest.Builder.forBroker(transactionDataAndCallbacks.transactionData), new AddPartitionsToTxnHandler(node, transactionDataAndCallbacks) - ) + )) removedNodes.add(node) } @@ -180,7 +185,7 @@ class AddPartitionsToTxnManager(config: KafkaConfig, client: NetworkClient, time nodesToTransactions.remove(node) } } - buffer + list } } diff --git a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala index 2d259c8a2f4..bbf0792fc37 100644 --- a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala +++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala @@ -19,7 +19,6 @@ package kafka.server import java.util.concurrent.LinkedBlockingDeque import java.util.concurrent.atomic.AtomicReference -import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} import kafka.raft.RaftManager import kafka.server.metadata.ZkMetadataCache import kafka.utils.Logging @@ -33,7 +32,9 @@ import org.apache.kafka.common.security.JaasContext import org.apache.kafka.common.security.auth.SecurityProtocol import org.apache.kafka.common.utils.{LogContext, Time} import org.apache.kafka.server.common.ApiMessageAndVersion +import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler} +import java.util import scala.collection.Seq import scala.compat.java8.OptionConverters._ import scala.jdk.CollectionConverters._ @@ -306,8 +307,10 @@ class BrokerToControllerRequestThread( initialNetworkClient, Math.min(Int.MaxValue, Math.min(config.controllerSocketTimeoutMs, retryTimeoutMs)).toInt, time, - isInterruptible = false -) { + false +) with Logging { + + this.logIdent = logPrefix private def maybeResetNetworkClient(controllerInformation: ControllerInformation): Unit = { if (isNetworkClientForZkController != controllerInformation.isZkController) { @@ -354,7 +357,7 @@ class BrokerToControllerRequestThread( requestQueue.size } - override def generateRequests(): Iterable[RequestAndCompletionHandler] = { + override def generateRequests(): util.Collection[RequestAndCompletionHandler] = { val currentTimeMs = time.milliseconds() val requestIter = requestQueue.iterator() while (requestIter.hasNext) { @@ -366,16 +369,16 @@ class BrokerToControllerRequestThread( val controllerAddress = activeControllerAddress() if (controllerAddress.isDefined) { requestIter.remove() - return Some(RequestAndCompletionHandler( + return util.Collections.singletonList(new RequestAndCompletionHandler( time.milliseconds(), controllerAddress.get, request.request, - handleResponse(request) + response => handleResponse(request)(response) )) } } } - None + util.Collections.emptyList() } private[server] def handleResponse(queueItem: BrokerToControllerQueueItem)(response: ClientResponse): Unit = { @@ -426,7 +429,7 @@ class BrokerToControllerRequestThread( case None => // need to backoff to avoid tight loops debug("No controller provided, retrying after backoff") - super.pollOnce(maxTimeoutMs = 100) + super.pollOnce(100) } } } diff --git a/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala b/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala deleted file mode 100644 index 9ca36a3d8df..00000000000 --- a/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package kafka.common - -import org.apache.kafka.clients.{ClientRequest, ClientResponse, NetworkClient, RequestCompletionHandler} -import org.apache.kafka.common.Node -import org.apache.kafka.common.errors.{AuthenticationException, DisconnectException} -import org.apache.kafka.common.protocol.ApiKeys -import org.apache.kafka.common.requests.AbstractRequest -import org.apache.kafka.server.util.MockTime -import org.junit.jupiter.api.Assertions._ -import org.junit.jupiter.api.Test -import org.mockito.ArgumentMatchers.{any, anyLong, same} -import org.mockito.ArgumentMatchers -import org.mockito.Mockito.{mock, verify, when} - -import java.util -import scala.collection.mutable - -class InterBrokerSendThreadTest { - private val time = new MockTime() - private val networkClient: NetworkClient = mock(classOf[NetworkClient]) - private val completionHandler = new StubCompletionHandler - private val requestTimeoutMs = 1000 - - class TestInterBrokerSendThread(networkClient: NetworkClient = networkClient, - exceptionCallback: Throwable => Unit = t => throw t) - extends InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) { - private val queue = mutable.Queue[RequestAndCompletionHandler]() - - def enqueue(request: RequestAndCompletionHandler): Unit = { - queue += request - } - - override def generateRequests(): Iterable[RequestAndCompletionHandler] = { - if (queue.isEmpty) { - None - } else { - Some(queue.dequeue()) - } - } - override def pollOnce(maxTimeoutMs: Long): Unit = { - try super.pollOnce(maxTimeoutMs) - catch { - case e: Throwable => exceptionCallback(e) - } - } - - } - - @Test - def shutdownThreadShouldNotCauseException(): Unit = { - // InterBrokerSendThread#shutdown calls NetworkClient#initiateClose first so NetworkClient#poll - // can throw DisconnectException when thread is running - when(networkClient.poll(anyLong(), anyLong())).thenThrow(new DisconnectException()) - var exception: Throwable = null - val thread = new TestInterBrokerSendThread(networkClient, e => exception = e) - thread.shutdown() - thread.pollOnce(100) - - verify(networkClient).poll(anyLong(), anyLong()) - assertNull(exception) - } - - @Test - def shouldNotSendAnythingWhenNoRequests(): Unit = { - val sendThread = new TestInterBrokerSendThread() - - // poll is always called but there should be no further invocations on NetworkClient - when(networkClient.poll(anyLong(), anyLong())) - .thenReturn(new util.ArrayList[ClientResponse]()) - - sendThread.doWork() - - verify(networkClient).poll(anyLong(), anyLong()) - assertFalse(completionHandler.executedWithDisconnectedResponse) - } - - @Test - def shouldCreateClientRequestAndSendWhenNodeIsReady(): Unit = { - val request = new StubRequestBuilder() - val node = new Node(1, "", 8080) - val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler) - val sendThread = new TestInterBrokerSendThread() - - val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler) - - when(networkClient.newClientRequest( - ArgumentMatchers.eq("1"), - same(handler.request), - anyLong(), - ArgumentMatchers.eq(true), - ArgumentMatchers.eq(requestTimeoutMs), - same(handler.handler))) - .thenReturn(clientRequest) - - when(networkClient.ready(node, time.milliseconds())) - .thenReturn(true) - - when(networkClient.poll(anyLong(), anyLong())) - .thenReturn(new util.ArrayList[ClientResponse]()) - - sendThread.enqueue(handler) - sendThread.doWork() - - verify(networkClient).newClientRequest( - ArgumentMatchers.eq("1"), - same(handler.request), - anyLong(), - ArgumentMatchers.eq(true), - ArgumentMatchers.eq(requestTimeoutMs), - same(handler.handler)) - verify(networkClient).ready(any[Node], anyLong()) - verify(networkClient).send(same(clientRequest), anyLong()) - verify(networkClient).poll(anyLong(), anyLong()) - assertFalse(completionHandler.executedWithDisconnectedResponse) - } - - @Test - def shouldCallCompletionHandlerWithDisconnectedResponseWhenNodeNotReady(): Unit = { - val request = new StubRequestBuilder - val node = new Node(1, "", 8080) - val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler) - val sendThread = new TestInterBrokerSendThread() - - val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler) - - when(networkClient.newClientRequest( - ArgumentMatchers.eq("1"), - same(handler.request), - anyLong(), - ArgumentMatchers.eq(true), - ArgumentMatchers.eq(requestTimeoutMs), - same(handler.handler))) - .thenReturn(clientRequest) - - when(networkClient.ready(node, time.milliseconds())) - .thenReturn(false) - - when(networkClient.connectionDelay(any[Node], anyLong())) - .thenReturn(0) - - when(networkClient.poll(anyLong(), anyLong())) - .thenReturn(new util.ArrayList[ClientResponse]()) - - when(networkClient.connectionFailed(node)) - .thenReturn(true) - - when(networkClient.authenticationException(node)) - .thenReturn(new AuthenticationException("")) - - sendThread.enqueue(handler) - sendThread.doWork() - - verify(networkClient).newClientRequest( - ArgumentMatchers.eq("1"), - same(handler.request), - anyLong, - ArgumentMatchers.eq(true), - ArgumentMatchers.eq(requestTimeoutMs), - same(handler.handler)) - verify(networkClient).ready(any[Node], anyLong) - verify(networkClient).connectionDelay(any[Node], anyLong) - verify(networkClient).poll(anyLong, anyLong) - verify(networkClient).connectionFailed(any[Node]) - verify(networkClient).authenticationException(any[Node]) - assertTrue(completionHandler.executedWithDisconnectedResponse) - } - - @Test - def testFailingExpiredRequests(): Unit = { - val request = new StubRequestBuilder() - val node = new Node(1, "", 8080) - val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler) - val sendThread = new TestInterBrokerSendThread() - - val clientRequest = new ClientRequest("dest", - request, - 0, - "1", - time.milliseconds(), - true, - requestTimeoutMs, - handler.handler) - time.sleep(1500) - - when(networkClient.newClientRequest( - ArgumentMatchers.eq("1"), - same(handler.request), - ArgumentMatchers.eq(handler.creationTimeMs), - ArgumentMatchers.eq(true), - ArgumentMatchers.eq(requestTimeoutMs), - same(handler.handler))) - .thenReturn(clientRequest) - - // make the node unready so the request is not cleared - when(networkClient.ready(node, time.milliseconds())) - .thenReturn(false) - - when(networkClient.connectionDelay(any[Node], anyLong())) - .thenReturn(0) - - when(networkClient.poll(anyLong(), anyLong())) - .thenReturn(new util.ArrayList[ClientResponse]()) - - // rule out disconnects so the request stays for the expiry check - when(networkClient.connectionFailed(node)) - .thenReturn(false) - - sendThread.enqueue(handler) - sendThread.doWork() - - verify(networkClient).newClientRequest( - ArgumentMatchers.eq("1"), - same(handler.request), - ArgumentMatchers.eq(handler.creationTimeMs), - ArgumentMatchers.eq(true), - ArgumentMatchers.eq(requestTimeoutMs), - same(handler.handler)) - verify(networkClient).ready(any[Node], anyLong) - verify(networkClient).connectionDelay(any[Node], anyLong) - verify(networkClient).poll(anyLong, anyLong) - verify(networkClient).connectionFailed(any[Node]) - - assertFalse(sendThread.hasUnsentRequests) - assertTrue(completionHandler.executedWithDisconnectedResponse) - } - - private class StubRequestBuilder extends AbstractRequest.Builder(ApiKeys.END_TXN) { - override def build(version: Short): Nothing = ??? - } - - private class StubCompletionHandler extends RequestCompletionHandler { - var executedWithDisconnectedResponse = false - var response: ClientResponse = _ - override def onComplete(response: ClientResponse): Unit = { - this.executedWithDisconnectedResponse = response.wasDisconnected() - this.response = response - } - } - -} diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala index c458ac191c0..9a0d8143766 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala @@ -384,7 +384,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren new WriteTxnMarkersResponse(pidErrorMap) } synchronized { - txnMarkerChannelManager.generateRequests().foreach { requestAndHandler => + txnMarkerChannelManager.generateRequests().asScala.foreach { requestAndHandler => val request = requestAndHandler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build() val response = createResponse(request) requestAndHandler.handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1), diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala index a1598099055..d966b3b43e0 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala @@ -20,8 +20,6 @@ import java.util import java.util.Arrays.asList import java.util.Collections import java.util.concurrent.{Callable, Executors, Future} - -import kafka.common.RequestAndCompletionHandler import kafka.server.{KafkaConfig, MetadataCache} import kafka.utils.TestUtils import org.apache.kafka.clients.{ClientResponse, NetworkClient} @@ -31,6 +29,7 @@ import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, Write import org.apache.kafka.common.utils.MockTime import org.apache.kafka.common.{Node, TopicPartition} import org.apache.kafka.server.metrics.KafkaYammerMetrics +import org.apache.kafka.server.util.RequestAndCompletionHandler import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test import org.mockito.ArgumentMatchers.any @@ -129,7 +128,7 @@ class TransactionMarkerChannelManagerTest { response) TestUtils.waitUntilTrue(() => { - val requests = channelManager.generateRequests() + val requests = channelManager.generateRequests().asScala if (requests.nonEmpty) { assertEquals(1, requests.size) val request = requests.head @@ -194,7 +193,7 @@ class TransactionMarkerChannelManagerTest { val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition2)))).build() - val requests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => + val requests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().asScala.map { handler => (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) }.toMap @@ -264,13 +263,13 @@ class TransactionMarkerChannelManagerTest { val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition2)))).build() - val firstDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => + val firstDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().asScala.map { handler => (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) }.toMap assertEquals(Map(broker2 -> expectedBroker2Request), firstDrainedRequests) - val secondDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => + val secondDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().asScala.map { handler => (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) }.toMap @@ -345,7 +344,7 @@ class TransactionMarkerChannelManagerTest { channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) - val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests().asScala val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) for (requestAndHandler <- requestAndHandlers) { @@ -398,7 +397,7 @@ class TransactionMarkerChannelManagerTest { channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) - val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests().asScala val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) for (requestAndHandler <- requestAndHandlers) { @@ -452,7 +451,7 @@ class TransactionMarkerChannelManagerTest { channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) - val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests().asScala val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) for (requestAndHandler <- requestAndHandlers) { diff --git a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala index a2c7fbcd232..01ced6ab5d4 100644 --- a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnManagerTest.scala @@ -17,7 +17,6 @@ package unit.kafka.server -import kafka.common.RequestAndCompletionHandler import kafka.server.{AddPartitionsToTxnManager, KafkaConfig} import kafka.utils.TestUtils import org.apache.kafka.clients.{ClientResponse, NetworkClient} @@ -29,6 +28,7 @@ import org.apache.kafka.common.{Node, TopicPartition} import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.requests.{AbstractResponse, AddPartitionsToTxnRequest, AddPartitionsToTxnResponse} import org.apache.kafka.common.utils.MockTime +import org.apache.kafka.server.util.RequestAndCompletionHandler import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} import org.mockito.Mockito.mock @@ -110,7 +110,7 @@ class AddPartitionsToTxnManagerTest { addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId1, producerId1, producerEpoch = 0), setErrors(transaction1RetryWithOldEpochErrors)) assertEquals(expectedEpochErrors, transaction1RetryWithOldEpochErrors) - val requestsAndHandlers = addPartitionsToTxnManager.generateRequests() + val requestsAndHandlers = addPartitionsToTxnManager.generateRequests().asScala requestsAndHandlers.foreach { requestAndHandler => if (requestAndHandler.destination == node0) { assertEquals(time.milliseconds(), requestAndHandler.creationTimeMs) @@ -130,7 +130,7 @@ class AddPartitionsToTxnManagerTest { addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId1, producerId1), setErrors(transactionErrors)) addPartitionsToTxnManager.addTxnData(node1, transactionData(transactionalId2, producerId2), setErrors(transactionErrors)) - val requestsAndHandlers = addPartitionsToTxnManager.generateRequests() + val requestsAndHandlers = addPartitionsToTxnManager.generateRequests().asScala assertEquals(2, requestsAndHandlers.size) // Note: handlers are tested in testAddPartitionsToTxnHandlerErrorHandling requestsAndHandlers.foreach{ requestAndHandler => @@ -147,7 +147,7 @@ class AddPartitionsToTxnManagerTest { // Test creationTimeMs increases too. time.sleep(1000) - val requestsAndHandlers2 = addPartitionsToTxnManager.generateRequests() + val requestsAndHandlers2 = addPartitionsToTxnManager.generateRequests().asScala // The request for node1 should not be added because one request is already inflight. assertEquals(1, requestsAndHandlers2.size) requestsAndHandlers2.foreach { requestAndHandler => @@ -156,7 +156,7 @@ class AddPartitionsToTxnManagerTest { // Complete the request for node1 so the new one can go through. requestsAndHandlers.filter(_.destination == node1).head.handler.onComplete(authenticationErrorResponse) - val requestsAndHandlers3 = addPartitionsToTxnManager.generateRequests() + val requestsAndHandlers3 = addPartitionsToTxnManager.generateRequests().asScala assertEquals(1, requestsAndHandlers3.size) requestsAndHandlers3.foreach { requestAndHandler => verifyRequest(node1, transactionalId2, producerId2, requestAndHandler) @@ -238,7 +238,7 @@ class AddPartitionsToTxnManagerTest { } private def receiveResponse(response: ClientResponse): Unit = { - addPartitionsToTxnManager.generateRequests().head.handler.onComplete(response) + addPartitionsToTxnManager.generateRequests().asScala.head.handler.onComplete(response) } private def verifyRequest(expectedDestination: Node, transactionalId: String, producerId: Long, requestAndHandler: RequestAndCompletionHandler): Unit = { diff --git a/server-common/src/main/java/org/apache/kafka/server/util/InterBrokerSendThread.java b/server-common/src/main/java/org/apache/kafka/server/util/InterBrokerSendThread.java new file mode 100644 index 00000000000..227c7e0064a --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/server/util/InterBrokerSendThread.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.util; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import org.apache.kafka.clients.ClientRequest; +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.KafkaClient; +import org.apache.kafka.clients.RequestCompletionHandler; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.internals.FatalExitError; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; + +/** + * An inter-broker send thread that utilizes a non-blocking network client. + */ +public abstract class InterBrokerSendThread extends ShutdownableThread { + + protected volatile KafkaClient networkClient; + + private final int requestTimeoutMs; + private final Time time; + private final UnsentRequests unsentRequests; + + public InterBrokerSendThread( + String name, + KafkaClient networkClient, + int requestTimeoutMs, + Time time + ) { + this(name, networkClient, requestTimeoutMs, time, true); + } + + public InterBrokerSendThread( + String name, + KafkaClient networkClient, + int requestTimeoutMs, + Time time, + boolean isInterruptible + ) { + super(name, isInterruptible); + this.networkClient = networkClient; + this.requestTimeoutMs = requestTimeoutMs; + this.time = time; + this.unsentRequests = new UnsentRequests(); + } + + public abstract Collection generateRequests(); + + public boolean hasUnsentRequests() { + return unsentRequests.iterator().hasNext(); + } + + @Override + public void shutdown() throws InterruptedException { + initiateShutdown(); + networkClient.initiateClose(); + awaitShutdown(); + Utils.closeQuietly(networkClient, "InterBrokerSendThread network client"); + } + + private void drainGeneratedRequests() { + generateRequests().forEach(request -> + unsentRequests.put( + request.destination, + networkClient.newClientRequest( + request.destination.idString(), + request.request, + request.creationTimeMs, + true, + requestTimeoutMs, + request.handler + ) + ) + ); + } + + protected void pollOnce(long maxTimeoutMs) { + try { + drainGeneratedRequests(); + long now = time.milliseconds(); + final long timeout = sendRequests(now, maxTimeoutMs); + networkClient.poll(timeout, now); + now = time.milliseconds(); + checkDisconnects(now); + failExpiredRequests(now); + unsentRequests.clean(); + } catch (FatalExitError fee) { + throw fee; + } catch (Throwable t) { + if (t instanceof DisconnectException && !networkClient.active()) { + // DisconnectException is expected when NetworkClient#initiateClose is called + return; + } + log.error("unhandled exception caught in InterBrokerSendThread", t); + // rethrow any unhandled exceptions as FatalExitError so the JVM will be terminated + // as we will be in an unknown state with potentially some requests dropped and not + // being able to make progress. Known and expected Errors should have been appropriately + // dealt with already. + throw new FatalExitError(); + } + } + + @Override + public void doWork() { + pollOnce(Long.MAX_VALUE); + } + + private long sendRequests(long now, long maxTimeoutMs) { + long pollTimeout = maxTimeoutMs; + for (Node node : unsentRequests.nodes()) { + final Iterator requestIterator = unsentRequests.requestIterator(node); + while (requestIterator.hasNext()) { + final ClientRequest request = requestIterator.next(); + if (networkClient.ready(node, now)) { + networkClient.send(request, now); + requestIterator.remove(); + } else { + pollTimeout = Math.min(pollTimeout, networkClient.connectionDelay(node, now)); + } + } + } + return pollTimeout; + } + + private void checkDisconnects(long now) { + // any disconnects affecting requests that have already been transmitted will be handled + // by NetworkClient, so we just need to check whether connections for any of the unsent + // requests have been disconnected; if they have, then we complete the corresponding future + // and set the disconnect flag in the ClientResponse + final Iterator>> iterator = unsentRequests.iterator(); + while (iterator.hasNext()) { + final Map.Entry> entry = iterator.next(); + final Node node = entry.getKey(); + final ArrayDeque requests = entry.getValue(); + if (!requests.isEmpty() && networkClient.connectionFailed(node)) { + iterator.remove(); + for (ClientRequest request : requests) { + final AuthenticationException authenticationException = networkClient.authenticationException(node); + if (authenticationException != null) { + log.error("Failed to send the following request due to authentication error: {}", request); + } + completeWithDisconnect(request, now, authenticationException); + } + } + } + } + + private void failExpiredRequests(long now) { + // clear all expired unsent requests + final Collection timedOutRequests = unsentRequests.removeAllTimedOut(now); + for (ClientRequest request : timedOutRequests) { + log.debug("Failed to send the following request after {} ms: {}", request.requestTimeoutMs(), request); + completeWithDisconnect(request, now, null); + } + } + + private static void completeWithDisconnect( + ClientRequest request, + long now, + AuthenticationException authenticationException + ) { + final RequestCompletionHandler handler = request.callback(); + handler.onComplete( + new ClientResponse( + request.makeHeader(request.requestBuilder().latestAllowedVersion()), + handler, + request.destination(), + now /* createdTimeMs */, + now /* receivedTimeMs */, + true /* disconnected */, + null /* versionMismatch */, + authenticationException, + null + ) + ); + } + + public void wakeup() { + networkClient.wakeup(); + } + + private static final class UnsentRequests { + + private final Map> unsent = new HashMap<>(); + + void put(Node node, ClientRequest request) { + ArrayDeque requests = unsent.computeIfAbsent(node, n -> new ArrayDeque<>()); + requests.add(request); + } + + Collection removeAllTimedOut(long now) { + final List expiredRequests = new ArrayList<>(); + for (ArrayDeque requests : unsent.values()) { + final Iterator requestIterator = requests.iterator(); + boolean foundExpiredRequest = false; + while (requestIterator.hasNext() && !foundExpiredRequest) { + final ClientRequest request = requestIterator.next(); + final long elapsedMs = Math.max(0, now - request.createdTimeMs()); + if (elapsedMs > request.requestTimeoutMs()) { + expiredRequests.add(request); + requestIterator.remove(); + foundExpiredRequest = true; + } + } + } + return expiredRequests; + } + + void clean() { + unsent.values().removeIf(ArrayDeque::isEmpty); + } + + Iterator>> iterator() { + return unsent.entrySet().iterator(); + } + + Iterator requestIterator(Node node) { + ArrayDeque requests = unsent.get(node); + return (requests == null) ? Collections.emptyIterator() : requests.iterator(); + } + + Set nodes() { + return unsent.keySet(); + } + } +} diff --git a/server-common/src/main/java/org/apache/kafka/server/util/RequestAndCompletionHandler.java b/server-common/src/main/java/org/apache/kafka/server/util/RequestAndCompletionHandler.java new file mode 100644 index 00000000000..da14fb5e4a4 --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/server/util/RequestAndCompletionHandler.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.util; + +import org.apache.kafka.clients.RequestCompletionHandler; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.requests.AbstractRequest; + +public final class RequestAndCompletionHandler { + + public final long creationTimeMs; + public final Node destination; + public final AbstractRequest.Builder request; + public final RequestCompletionHandler handler; + + public RequestAndCompletionHandler( + long creationTimeMs, + Node destination, + AbstractRequest.Builder request, + RequestCompletionHandler handler + ) { + this.creationTimeMs = creationTimeMs; + this.destination = destination; + this.request = request; + this.handler = handler; + } + + @Override + public String toString() { + return "RequestAndCompletionHandler(" + + "creationTimeMs=" + creationTimeMs + + ", destination=" + destination + + ", request=" + request + + ", handler=" + handler + + ')'; + } +} diff --git a/server-common/src/main/java/org/apache/kafka/server/util/ShutdownableThread.java b/server-common/src/main/java/org/apache/kafka/server/util/ShutdownableThread.java index e33f64b1c7b..4ef727d3040 100644 --- a/server-common/src/main/java/org/apache/kafka/server/util/ShutdownableThread.java +++ b/server-common/src/main/java/org/apache/kafka/server/util/ShutdownableThread.java @@ -29,7 +29,7 @@ public abstract class ShutdownableThread extends Thread { public final String logPrefix; - private final Logger log; + protected final Logger log; private final boolean isInterruptible; diff --git a/server-common/src/test/java/org/apache/kafka/server/util/InterBrokerSendThreadTest.java b/server-common/src/test/java/org/apache/kafka/server/util/InterBrokerSendThreadTest.java new file mode 100644 index 00000000000..c9b638d88f4 --- /dev/null +++ b/server-common/src/test/java/org/apache/kafka/server/util/InterBrokerSendThreadTest.java @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.util; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Collection; +import java.util.Collections; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import org.apache.kafka.clients.ClientRequest; +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.KafkaClient; +import org.apache.kafka.clients.RequestCompletionHandler; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.internals.FatalExitError; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatchers; + +public class InterBrokerSendThreadTest { + + private final Time time = new MockTime(); + private final KafkaClient networkClient = mock(KafkaClient.class); + private final StubCompletionHandler completionHandler = new StubCompletionHandler(); + private final int requestTimeoutMs = 1000; + + class TestInterBrokerSendThread extends InterBrokerSendThread { + + private final Consumer exceptionCallback; + private final Queue queue = new ArrayDeque<>(); + + TestInterBrokerSendThread() { + this( + InterBrokerSendThreadTest.this.networkClient, + t -> { + throw (t instanceof RuntimeException) + ? ((RuntimeException) t) + : new RuntimeException(t); + }); + } + + TestInterBrokerSendThread(KafkaClient networkClient, Consumer exceptionCallback) { + super("name", networkClient, requestTimeoutMs, time); + this.exceptionCallback = exceptionCallback; + } + + void enqueue(RequestAndCompletionHandler request) { + queue.offer(request); + } + + @Override + public Collection generateRequests() { + return queue.isEmpty() ? Collections.emptyList() : Collections.singletonList(queue.poll()); + } + + @Override + protected void pollOnce(long maxTimeoutMs) { + try { + super.pollOnce(maxTimeoutMs); + } catch (Throwable t) { + exceptionCallback.accept(t); + } + } + } + + @Test + public void testShutdownThreadShouldNotCauseException() throws InterruptedException, IOException { + // InterBrokerSendThread#shutdown calls NetworkClient#initiateClose first so NetworkClient#poll + // can throw DisconnectException when thread is running + when(networkClient.poll(anyLong(), anyLong())).thenThrow(new DisconnectException()); + when(networkClient.active()).thenReturn(false); + + AtomicReference exception = new AtomicReference<>(); + final InterBrokerSendThread thread = + new TestInterBrokerSendThread(networkClient, exception::getAndSet); + thread.shutdown(); + thread.pollOnce(100); + + verify(networkClient).poll(anyLong(), anyLong()); + verify(networkClient).initiateClose(); + verify(networkClient).close(); + verify(networkClient).active(); + verifyNoMoreInteractions(networkClient); + + assertNull(exception.get()); + } + + @Test + public void testDisconnectWithoutShutdownShouldCauseException() { + DisconnectException de = new DisconnectException(); + when(networkClient.poll(anyLong(), anyLong())).thenThrow(de); + when(networkClient.active()).thenReturn(true); + + AtomicReference throwable = new AtomicReference<>(); + final InterBrokerSendThread thread = + new TestInterBrokerSendThread(networkClient, throwable::getAndSet); + thread.pollOnce(100); + + verify(networkClient).poll(anyLong(), anyLong()); + verify(networkClient).active(); + verifyNoMoreInteractions(networkClient); + + Throwable thrown = throwable.get(); + assertNotNull(thrown); + assertTrue(thrown instanceof FatalExitError); + } + + @Test + public void testShouldNotSendAnythingWhenNoRequests() { + final InterBrokerSendThread sendThread = new TestInterBrokerSendThread(); + + // poll is always called but there should be no further invocations on NetworkClient + when(networkClient.poll(anyLong(), anyLong())).thenReturn(Collections.emptyList()); + + sendThread.doWork(); + + verify(networkClient).poll(anyLong(), anyLong()); + verifyNoMoreInteractions(networkClient); + + assertFalse(completionHandler.executedWithDisconnectedResponse); + } + + @Test + public void testShouldCreateClientRequestAndSendWhenNodeIsReady() { + final AbstractRequest.Builder request = new StubRequestBuilder<>(); + final Node node = new Node(1, "", 8080); + final RequestAndCompletionHandler handler = + new RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler); + final TestInterBrokerSendThread sendThread = new TestInterBrokerSendThread(); + + final ClientRequest clientRequest = + new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler); + + when(networkClient.newClientRequest( + ArgumentMatchers.eq("1"), + same(handler.request), + anyLong(), + ArgumentMatchers.eq(true), + ArgumentMatchers.eq(requestTimeoutMs), + same(handler.handler) + )).thenReturn(clientRequest); + + when(networkClient.ready(node, time.milliseconds())).thenReturn(true); + + when(networkClient.poll(anyLong(), anyLong())).thenReturn(Collections.emptyList()); + + sendThread.enqueue(handler); + sendThread.doWork(); + + verify(networkClient) + .newClientRequest( + ArgumentMatchers.eq("1"), + same(handler.request), + anyLong(), + ArgumentMatchers.eq(true), + ArgumentMatchers.eq(requestTimeoutMs), + same(handler.handler)); + verify(networkClient).ready(any(), anyLong()); + verify(networkClient).send(same(clientRequest), anyLong()); + verify(networkClient).poll(anyLong(), anyLong()); + verifyNoMoreInteractions(networkClient); + + assertFalse(completionHandler.executedWithDisconnectedResponse); + } + + @Test + public void testShouldCallCompletionHandlerWithDisconnectedResponseWhenNodeNotReady() { + final AbstractRequest.Builder request = new StubRequestBuilder<>(); + final Node node = new Node(1, "", 8080); + final RequestAndCompletionHandler handler = + new RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler); + final TestInterBrokerSendThread sendThread = new TestInterBrokerSendThread(); + + final ClientRequest clientRequest = + new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler); + + when(networkClient.newClientRequest( + ArgumentMatchers.eq("1"), + same(handler.request), + anyLong(), + ArgumentMatchers.eq(true), + ArgumentMatchers.eq(requestTimeoutMs), + same(handler.handler) + )).thenReturn(clientRequest); + + when(networkClient.ready(node, time.milliseconds())).thenReturn(false); + + when(networkClient.connectionDelay(any(), anyLong())).thenReturn(0L); + + when(networkClient.poll(anyLong(), anyLong())).thenReturn(Collections.emptyList()); + + when(networkClient.connectionFailed(node)).thenReturn(true); + + when(networkClient.authenticationException(node)).thenReturn(new AuthenticationException("")); + + sendThread.enqueue(handler); + sendThread.doWork(); + + verify(networkClient) + .newClientRequest( + ArgumentMatchers.eq("1"), + same(handler.request), + anyLong(), + ArgumentMatchers.eq(true), + ArgumentMatchers.eq(requestTimeoutMs), + same(handler.handler)); + verify(networkClient).ready(any(), anyLong()); + verify(networkClient).connectionDelay(any(), anyLong()); + verify(networkClient).poll(anyLong(), anyLong()); + verify(networkClient).connectionFailed(any()); + verify(networkClient).authenticationException(any()); + verifyNoMoreInteractions(networkClient); + + assertTrue(completionHandler.executedWithDisconnectedResponse); + } + + @Test + public void testFailingExpiredRequests() { + final AbstractRequest.Builder request = new StubRequestBuilder<>(); + final Node node = new Node(1, "", 8080); + final RequestAndCompletionHandler handler = + new RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler); + final TestInterBrokerSendThread sendThread = new TestInterBrokerSendThread(); + + final ClientRequest clientRequest = + new ClientRequest( + "dest", request, 0, "1", time.milliseconds(), true, requestTimeoutMs, handler.handler); + time.sleep(1500L); + + when(networkClient.newClientRequest( + ArgumentMatchers.eq("1"), + same(handler.request), + ArgumentMatchers.eq(handler.creationTimeMs), + ArgumentMatchers.eq(true), + ArgumentMatchers.eq(requestTimeoutMs), + same(handler.handler) + )).thenReturn(clientRequest); + + // make the node unready so the request is not cleared + when(networkClient.ready(node, time.milliseconds())).thenReturn(false); + + when(networkClient.connectionDelay(any(), anyLong())).thenReturn(0L); + + when(networkClient.poll(anyLong(), anyLong())).thenReturn(Collections.emptyList()); + + // rule out disconnects so the request stays for the expiry check + when(networkClient.connectionFailed(node)).thenReturn(false); + + sendThread.enqueue(handler); + sendThread.doWork(); + + verify(networkClient) + .newClientRequest( + ArgumentMatchers.eq("1"), + same(handler.request), + ArgumentMatchers.eq(handler.creationTimeMs), + ArgumentMatchers.eq(true), + ArgumentMatchers.eq(requestTimeoutMs), + same(handler.handler)); + verify(networkClient).ready(any(), anyLong()); + verify(networkClient).connectionDelay(any(), anyLong()); + verify(networkClient).poll(anyLong(), anyLong()); + verify(networkClient).connectionFailed(any()); + verifyNoMoreInteractions(networkClient); + + assertFalse(sendThread.hasUnsentRequests()); + assertTrue(completionHandler.executedWithDisconnectedResponse); + } + + private static class StubRequestBuilder + extends AbstractRequest.Builder { + + private StubRequestBuilder() { + super(ApiKeys.END_TXN); + } + + @Override + public T build(short version) { + return null; + } + } + + private static class StubCompletionHandler implements RequestCompletionHandler { + + public boolean executedWithDisconnectedResponse = false; + ClientResponse response = null; + + @Override + public void onComplete(ClientResponse response) { + this.executedWithDisconnectedResponse = response.wasDisconnected(); + this.response = response; + } + } +}