KAFKA-15087 Move/rewrite InterBrokerSendThread to server-commons (#13856)

The Java rewrite is kept relatively close to the Scala original
to minimize potential newly introduced bugs and to make reviewing
simpler. The following details might be of note:
- The `Logging` trait moved to InterBrokerSendThread with the
rewrite of ShutdownableThread has been similarly moved to any
subclasses that currently use it. InterBrokerSendThread's own
logging has been made to use ShutdownableThread's logger which
mimics the prefix/log identifier that the trait provided.
- The case RequestAndCompletionHandler class has been made a
separate POJO class and the internal-use UnsentRequests class
has been kept as a static nested class.
- The relatively commonly used but internal (not part of the
public API) clients classes that InterBrokerSendThread relies on
have been allowlisted in the server-common import control.
- The accompanying test class has also been moved and rewritten
with one new test added and most of the pre-existing tests made
stricter.

Reviewers: David Jacot <djacot@confluent.io>
This commit is contained in:
Dimitar Dimitrov 2023-06-20 16:50:46 +02:00 committed by GitHub
parent 0e8c436c7d
commit b100f1efac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 695 additions and 523 deletions

View File

@ -81,6 +81,15 @@
<subpackage name="network">
<allow pkg="org.apache.kafka.server.authorizer" />
</subpackage>
<!-- InterBrokerSendThread uses some clients classes that are not part of the public -->
<!-- API but are still relatively common -->
<subpackage name="util">
<allow class="org.apache.kafka.clients.ClientRequest" />
<allow class="org.apache.kafka.clients.ClientResponse" />
<allow class="org.apache.kafka.clients.KafkaClient" />
<allow class="org.apache.kafka.clients.RequestCompletionHandler" />
</subpackage>
</subpackage>
</import-control>

View File

@ -1,218 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.common
import kafka.utils.Logging
import java.util.Map.Entry
import java.util.{ArrayDeque, ArrayList, Collection, Collections, HashMap, Iterator}
import org.apache.kafka.clients.{ClientRequest, ClientResponse, KafkaClient, RequestCompletionHandler}
import org.apache.kafka.common.Node
import org.apache.kafka.common.errors.{AuthenticationException, DisconnectException}
import org.apache.kafka.common.internals.FatalExitError
import org.apache.kafka.common.requests.AbstractRequest
import org.apache.kafka.common.utils.Time
import org.apache.kafka.server.util.ShutdownableThread
import scala.jdk.CollectionConverters._
/**
* Class for inter-broker send thread that utilize a non-blocking network client.
*/
abstract class InterBrokerSendThread(
name: String,
@volatile var networkClient: KafkaClient,
requestTimeoutMs: Int,
time: Time,
isInterruptible: Boolean = true
) extends ShutdownableThread(name, isInterruptible) with Logging {
this.logIdent = logPrefix
private val unsentRequests = new UnsentRequests
def generateRequests(): Iterable[RequestAndCompletionHandler]
def hasUnsentRequests: Boolean = unsentRequests.iterator().hasNext
override def shutdown(): Unit = {
initiateShutdown()
networkClient.initiateClose()
awaitShutdown()
networkClient.close()
}
private def drainGeneratedRequests(): Unit = {
generateRequests().foreach { request =>
unsentRequests.put(request.destination,
networkClient.newClientRequest(
request.destination.idString,
request.request,
request.creationTimeMs,
true,
requestTimeoutMs,
request.handler
))
}
}
protected def pollOnce(maxTimeoutMs: Long): Unit = {
try {
drainGeneratedRequests()
var now = time.milliseconds()
val timeout = sendRequests(now, maxTimeoutMs)
networkClient.poll(timeout, now)
now = time.milliseconds()
checkDisconnects(now)
failExpiredRequests(now)
unsentRequests.clean()
} catch {
case _: DisconnectException if !networkClient.active() =>
// DisconnectException is expected when NetworkClient#initiateClose is called
case e: FatalExitError => throw e
case t: Throwable =>
error(s"unhandled exception caught in InterBrokerSendThread", t)
// rethrow any unhandled exceptions as FatalExitError so the JVM will be terminated
// as we will be in an unknown state with potentially some requests dropped and not
// being able to make progress. Known and expected Errors should have been appropriately
// dealt with already.
throw new FatalExitError()
}
}
override def doWork(): Unit = {
pollOnce(Long.MaxValue)
}
private def sendRequests(now: Long, maxTimeoutMs: Long): Long = {
var pollTimeout = maxTimeoutMs
for (node <- unsentRequests.nodes.asScala) {
val requestIterator = unsentRequests.requestIterator(node)
while (requestIterator.hasNext) {
val request = requestIterator.next
if (networkClient.ready(node, now)) {
networkClient.send(request, now)
requestIterator.remove()
} else
pollTimeout = Math.min(pollTimeout, networkClient.connectionDelay(node, now))
}
}
pollTimeout
}
private def checkDisconnects(now: Long): Unit = {
// any disconnects affecting requests that have already been transmitted will be handled
// by NetworkClient, so we just need to check whether connections for any of the unsent
// requests have been disconnected; if they have, then we complete the corresponding future
// and set the disconnect flag in the ClientResponse
val iterator = unsentRequests.iterator()
while (iterator.hasNext) {
val entry = iterator.next
val (node, requests) = (entry.getKey, entry.getValue)
if (!requests.isEmpty && networkClient.connectionFailed(node)) {
iterator.remove()
for (request <- requests.asScala) {
val authenticationException = networkClient.authenticationException(node)
if (authenticationException != null)
error(s"Failed to send the following request due to authentication error: $request")
completeWithDisconnect(request, now, authenticationException)
}
}
}
}
private def failExpiredRequests(now: Long): Unit = {
// clear all expired unsent requests
val timedOutRequests = unsentRequests.removeAllTimedOut(now)
for (request <- timedOutRequests.asScala) {
debug(s"Failed to send the following request after ${request.requestTimeoutMs} ms: $request")
completeWithDisconnect(request, now, null)
}
}
def completeWithDisconnect(request: ClientRequest,
now: Long,
authenticationException: AuthenticationException): Unit = {
val handler = request.callback
handler.onComplete(new ClientResponse(request.makeHeader(request.requestBuilder().latestAllowedVersion()),
handler, request.destination, now /* createdTimeMs */ , now /* receivedTimeMs */ , true /* disconnected */ ,
null /* versionMismatch */ , authenticationException, null))
}
def wakeup(): Unit = networkClient.wakeup()
}
case class RequestAndCompletionHandler(
creationTimeMs: Long,
destination: Node,
request: AbstractRequest.Builder[_ <: AbstractRequest],
handler: RequestCompletionHandler
)
private class UnsentRequests {
private val unsent = new HashMap[Node, ArrayDeque[ClientRequest]]
def put(node: Node, request: ClientRequest): Unit = {
var requests = unsent.get(node)
if (requests == null) {
requests = new ArrayDeque[ClientRequest]
unsent.put(node, requests)
}
requests.add(request)
}
def removeAllTimedOut(now: Long): Collection[ClientRequest] = {
val expiredRequests = new ArrayList[ClientRequest]
for (requests <- unsent.values.asScala) {
val requestIterator = requests.iterator
var foundExpiredRequest = false
while (requestIterator.hasNext && !foundExpiredRequest) {
val request = requestIterator.next
val elapsedMs = Math.max(0, now - request.createdTimeMs)
if (elapsedMs > request.requestTimeoutMs) {
expiredRequests.add(request)
requestIterator.remove()
foundExpiredRequest = true
}
}
}
expiredRequests
}
def clean(): Unit = {
val iterator = unsent.values.iterator
while (iterator.hasNext) {
val requests = iterator.next
if (requests.isEmpty)
iterator.remove()
}
}
def iterator(): Iterator[Entry[Node, ArrayDeque[ClientRequest]]] = {
unsent.entrySet().iterator()
}
def requestIterator(node: Node): Iterator[ClientRequest] = {
val requests = unsent.get(node)
if (requests == null)
Collections.emptyIterator[ClientRequest]
else
requests.iterator
}
def nodes: java.util.Set[Node] = unsent.keySet
}

View File

@ -19,8 +19,6 @@ package kafka.coordinator.transaction
import java.util
import java.util.concurrent.{BlockingQueue, ConcurrentHashMap, LinkedBlockingQueue}
import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler}
import kafka.server.{KafkaConfig, MetadataCache, RequestLocal}
import kafka.utils.Implicits._
import kafka.utils.{CoreUtils, Logging}
@ -35,6 +33,7 @@ import org.apache.kafka.common.utils.{LogContext, Time}
import org.apache.kafka.common.{Node, Reconfigurable, TopicPartition}
import org.apache.kafka.server.common.MetadataVersion.IBP_2_8_IV0
import org.apache.kafka.server.metrics.KafkaMetricsGroup
import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler}
import scala.collection.{concurrent, immutable}
import scala.jdk.CollectionConverters._
@ -183,7 +182,7 @@ class TransactionMarkerChannelManager(
}
def retryLogAppends(): Unit = {
val txnLogAppendRetries: java.util.List[PendingCompleteTxn] = new util.ArrayList[PendingCompleteTxn]()
val txnLogAppendRetries: util.List[PendingCompleteTxn] = new util.ArrayList[PendingCompleteTxn]()
txnLogAppendRetryQueue.drainTo(txnLogAppendRetries)
txnLogAppendRetries.forEach { txnLogAppend =>
debug(s"Retry appending $txnLogAppend transaction log")
@ -191,9 +190,9 @@ class TransactionMarkerChannelManager(
}
}
override def generateRequests(): Iterable[RequestAndCompletionHandler] = {
override def generateRequests(): util.Collection[RequestAndCompletionHandler] = {
retryLogAppends()
val txnIdAndMarkerEntries: java.util.List[TxnIdAndMarkerEntry] = new util.ArrayList[TxnIdAndMarkerEntry]()
val txnIdAndMarkerEntries: util.List[TxnIdAndMarkerEntry] = new util.ArrayList[TxnIdAndMarkerEntry]()
markersQueueForUnknownBroker.forEachTxnTopicPartition { case (_, queue) =>
queue.drainTo(txnIdAndMarkerEntries)
}
@ -221,13 +220,13 @@ class TransactionMarkerChannelManager(
val requestCompletionHandler = new TransactionMarkerRequestCompletionHandler(node.id, txnStateManager, this, entries)
val request = new WriteTxnMarkersRequest.Builder(writeTxnMarkersRequestVersion, markersToSend)
RequestAndCompletionHandler(
new RequestAndCompletionHandler(
currentTimeMs,
node,
request,
requestCompletionHandler
)
}
}.asJavaCollection
}
private def writeTxnCompletion(pendingCompleteTxn: PendingCompleteTxn): Unit = {

View File

@ -16,7 +16,6 @@
*/
package kafka.raft
import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler}
import kafka.utils.Logging
import org.apache.kafka.clients.{ClientResponse, KafkaClient}
import org.apache.kafka.common.Node
@ -26,7 +25,9 @@ import org.apache.kafka.common.requests._
import org.apache.kafka.common.utils.Time
import org.apache.kafka.raft.RaftConfig.InetAddressSpec
import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, RaftUtil}
import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler}
import java.util
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable
@ -67,17 +68,17 @@ private[raft] class RaftSendThread(
) {
private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]()
def generateRequests(): Iterable[RequestAndCompletionHandler] = {
val buffer = mutable.Buffer[RequestAndCompletionHandler]()
def generateRequests(): util.Collection[RequestAndCompletionHandler] = {
val list = new util.ArrayList[RequestAndCompletionHandler]()
while (true) {
val request = queue.poll()
if (request == null) {
return buffer
return list
} else {
buffer += request
list.add(request)
}
}
buffer
list
}
def sendRequest(request: RequestAndCompletionHandler): Unit = {
@ -142,11 +143,11 @@ class KafkaNetworkChannel(
endpoints.get(request.destinationId) match {
case Some(node) =>
requestThread.sendRequest(RequestAndCompletionHandler(
requestThread.sendRequest(new RequestAndCompletionHandler(
request.createdTimeMs,
destination = node,
request = buildRequest(request.data),
handler = onComplete
node,
buildRequest(request.data),
onComplete
))
case None =>

View File

@ -17,14 +17,16 @@
package kafka.server
import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler}
import kafka.utils.Logging
import org.apache.kafka.clients.{ClientResponse, NetworkClient, RequestCompletionHandler}
import org.apache.kafka.common.{Node, TopicPartition}
import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTransaction, AddPartitionsToTxnTransactionCollection}
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.requests.{AddPartitionsToTxnRequest, AddPartitionsToTxnResponse}
import org.apache.kafka.common.utils.Time
import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler}
import java.util
import scala.collection.mutable
object AddPartitionsToTxnManager {
@ -37,7 +39,10 @@ class TransactionDataAndCallbacks(val transactionData: AddPartitionsToTxnTransac
class AddPartitionsToTxnManager(config: KafkaConfig, client: NetworkClient, time: Time)
extends InterBrokerSendThread("AddPartitionsToTxnSenderThread-" + config.brokerId, client, config.requestTimeoutMs, time) {
extends InterBrokerSendThread("AddPartitionsToTxnSenderThread-" + config.brokerId, client, config.requestTimeoutMs, time)
with Logging {
this.logIdent = logPrefix
private val inflightNodes = mutable.HashSet[Node]()
private val nodesToTransactions = mutable.Map[Node, TransactionDataAndCallbacks]()
@ -157,20 +162,20 @@ class AddPartitionsToTxnManager(config: KafkaConfig, client: NetworkClient, time
}
}
override def generateRequests(): Iterable[RequestAndCompletionHandler] = {
override def generateRequests(): util.Collection[RequestAndCompletionHandler] = {
// build and add requests to queue
val buffer = mutable.Buffer[RequestAndCompletionHandler]()
val list = new util.ArrayList[RequestAndCompletionHandler]()
val currentTimeMs = time.milliseconds()
val removedNodes = mutable.Set[Node]()
nodesToTransactions.synchronized {
nodesToTransactions.foreach { case (node, transactionDataAndCallbacks) =>
if (!inflightNodes.contains(node)) {
buffer += RequestAndCompletionHandler(
list.add(new RequestAndCompletionHandler(
currentTimeMs,
node,
AddPartitionsToTxnRequest.Builder.forBroker(transactionDataAndCallbacks.transactionData),
new AddPartitionsToTxnHandler(node, transactionDataAndCallbacks)
)
))
removedNodes.add(node)
}
@ -180,7 +185,7 @@ class AddPartitionsToTxnManager(config: KafkaConfig, client: NetworkClient, time
nodesToTransactions.remove(node)
}
}
buffer
list
}
}

View File

@ -19,7 +19,6 @@ package kafka.server
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.atomic.AtomicReference
import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler}
import kafka.raft.RaftManager
import kafka.server.metadata.ZkMetadataCache
import kafka.utils.Logging
@ -33,7 +32,9 @@ import org.apache.kafka.common.security.JaasContext
import org.apache.kafka.common.security.auth.SecurityProtocol
import org.apache.kafka.common.utils.{LogContext, Time}
import org.apache.kafka.server.common.ApiMessageAndVersion
import org.apache.kafka.server.util.{InterBrokerSendThread, RequestAndCompletionHandler}
import java.util
import scala.collection.Seq
import scala.compat.java8.OptionConverters._
import scala.jdk.CollectionConverters._
@ -306,8 +307,10 @@ class BrokerToControllerRequestThread(
initialNetworkClient,
Math.min(Int.MaxValue, Math.min(config.controllerSocketTimeoutMs, retryTimeoutMs)).toInt,
time,
isInterruptible = false
) {
false
) with Logging {
this.logIdent = logPrefix
private def maybeResetNetworkClient(controllerInformation: ControllerInformation): Unit = {
if (isNetworkClientForZkController != controllerInformation.isZkController) {
@ -354,7 +357,7 @@ class BrokerToControllerRequestThread(
requestQueue.size
}
override def generateRequests(): Iterable[RequestAndCompletionHandler] = {
override def generateRequests(): util.Collection[RequestAndCompletionHandler] = {
val currentTimeMs = time.milliseconds()
val requestIter = requestQueue.iterator()
while (requestIter.hasNext) {
@ -366,16 +369,16 @@ class BrokerToControllerRequestThread(
val controllerAddress = activeControllerAddress()
if (controllerAddress.isDefined) {
requestIter.remove()
return Some(RequestAndCompletionHandler(
return util.Collections.singletonList(new RequestAndCompletionHandler(
time.milliseconds(),
controllerAddress.get,
request.request,
handleResponse(request)
response => handleResponse(request)(response)
))
}
}
}
None
util.Collections.emptyList()
}
private[server] def handleResponse(queueItem: BrokerToControllerQueueItem)(response: ClientResponse): Unit = {
@ -426,7 +429,7 @@ class BrokerToControllerRequestThread(
case None =>
// need to backoff to avoid tight loops
debug("No controller provided, retrying after backoff")
super.pollOnce(maxTimeoutMs = 100)
super.pollOnce(100)
}
}
}

View File

@ -1,256 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.common
import org.apache.kafka.clients.{ClientRequest, ClientResponse, NetworkClient, RequestCompletionHandler}
import org.apache.kafka.common.Node
import org.apache.kafka.common.errors.{AuthenticationException, DisconnectException}
import org.apache.kafka.common.protocol.ApiKeys
import org.apache.kafka.common.requests.AbstractRequest
import org.apache.kafka.server.util.MockTime
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test
import org.mockito.ArgumentMatchers.{any, anyLong, same}
import org.mockito.ArgumentMatchers
import org.mockito.Mockito.{mock, verify, when}
import java.util
import scala.collection.mutable
class InterBrokerSendThreadTest {
private val time = new MockTime()
private val networkClient: NetworkClient = mock(classOf[NetworkClient])
private val completionHandler = new StubCompletionHandler
private val requestTimeoutMs = 1000
class TestInterBrokerSendThread(networkClient: NetworkClient = networkClient,
exceptionCallback: Throwable => Unit = t => throw t)
extends InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) {
private val queue = mutable.Queue[RequestAndCompletionHandler]()
def enqueue(request: RequestAndCompletionHandler): Unit = {
queue += request
}
override def generateRequests(): Iterable[RequestAndCompletionHandler] = {
if (queue.isEmpty) {
None
} else {
Some(queue.dequeue())
}
}
override def pollOnce(maxTimeoutMs: Long): Unit = {
try super.pollOnce(maxTimeoutMs)
catch {
case e: Throwable => exceptionCallback(e)
}
}
}
@Test
def shutdownThreadShouldNotCauseException(): Unit = {
// InterBrokerSendThread#shutdown calls NetworkClient#initiateClose first so NetworkClient#poll
// can throw DisconnectException when thread is running
when(networkClient.poll(anyLong(), anyLong())).thenThrow(new DisconnectException())
var exception: Throwable = null
val thread = new TestInterBrokerSendThread(networkClient, e => exception = e)
thread.shutdown()
thread.pollOnce(100)
verify(networkClient).poll(anyLong(), anyLong())
assertNull(exception)
}
@Test
def shouldNotSendAnythingWhenNoRequests(): Unit = {
val sendThread = new TestInterBrokerSendThread()
// poll is always called but there should be no further invocations on NetworkClient
when(networkClient.poll(anyLong(), anyLong()))
.thenReturn(new util.ArrayList[ClientResponse]())
sendThread.doWork()
verify(networkClient).poll(anyLong(), anyLong())
assertFalse(completionHandler.executedWithDisconnectedResponse)
}
@Test
def shouldCreateClientRequestAndSendWhenNodeIsReady(): Unit = {
val request = new StubRequestBuilder()
val node = new Node(1, "", 8080)
val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler)
val sendThread = new TestInterBrokerSendThread()
val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler)
when(networkClient.newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
anyLong(),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler)))
.thenReturn(clientRequest)
when(networkClient.ready(node, time.milliseconds()))
.thenReturn(true)
when(networkClient.poll(anyLong(), anyLong()))
.thenReturn(new util.ArrayList[ClientResponse]())
sendThread.enqueue(handler)
sendThread.doWork()
verify(networkClient).newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
anyLong(),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler))
verify(networkClient).ready(any[Node], anyLong())
verify(networkClient).send(same(clientRequest), anyLong())
verify(networkClient).poll(anyLong(), anyLong())
assertFalse(completionHandler.executedWithDisconnectedResponse)
}
@Test
def shouldCallCompletionHandlerWithDisconnectedResponseWhenNodeNotReady(): Unit = {
val request = new StubRequestBuilder
val node = new Node(1, "", 8080)
val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler)
val sendThread = new TestInterBrokerSendThread()
val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler)
when(networkClient.newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
anyLong(),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler)))
.thenReturn(clientRequest)
when(networkClient.ready(node, time.milliseconds()))
.thenReturn(false)
when(networkClient.connectionDelay(any[Node], anyLong()))
.thenReturn(0)
when(networkClient.poll(anyLong(), anyLong()))
.thenReturn(new util.ArrayList[ClientResponse]())
when(networkClient.connectionFailed(node))
.thenReturn(true)
when(networkClient.authenticationException(node))
.thenReturn(new AuthenticationException(""))
sendThread.enqueue(handler)
sendThread.doWork()
verify(networkClient).newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
anyLong,
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler))
verify(networkClient).ready(any[Node], anyLong)
verify(networkClient).connectionDelay(any[Node], anyLong)
verify(networkClient).poll(anyLong, anyLong)
verify(networkClient).connectionFailed(any[Node])
verify(networkClient).authenticationException(any[Node])
assertTrue(completionHandler.executedWithDisconnectedResponse)
}
@Test
def testFailingExpiredRequests(): Unit = {
val request = new StubRequestBuilder()
val node = new Node(1, "", 8080)
val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler)
val sendThread = new TestInterBrokerSendThread()
val clientRequest = new ClientRequest("dest",
request,
0,
"1",
time.milliseconds(),
true,
requestTimeoutMs,
handler.handler)
time.sleep(1500)
when(networkClient.newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
ArgumentMatchers.eq(handler.creationTimeMs),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler)))
.thenReturn(clientRequest)
// make the node unready so the request is not cleared
when(networkClient.ready(node, time.milliseconds()))
.thenReturn(false)
when(networkClient.connectionDelay(any[Node], anyLong()))
.thenReturn(0)
when(networkClient.poll(anyLong(), anyLong()))
.thenReturn(new util.ArrayList[ClientResponse]())
// rule out disconnects so the request stays for the expiry check
when(networkClient.connectionFailed(node))
.thenReturn(false)
sendThread.enqueue(handler)
sendThread.doWork()
verify(networkClient).newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
ArgumentMatchers.eq(handler.creationTimeMs),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler))
verify(networkClient).ready(any[Node], anyLong)
verify(networkClient).connectionDelay(any[Node], anyLong)
verify(networkClient).poll(anyLong, anyLong)
verify(networkClient).connectionFailed(any[Node])
assertFalse(sendThread.hasUnsentRequests)
assertTrue(completionHandler.executedWithDisconnectedResponse)
}
private class StubRequestBuilder extends AbstractRequest.Builder(ApiKeys.END_TXN) {
override def build(version: Short): Nothing = ???
}
private class StubCompletionHandler extends RequestCompletionHandler {
var executedWithDisconnectedResponse = false
var response: ClientResponse = _
override def onComplete(response: ClientResponse): Unit = {
this.executedWithDisconnectedResponse = response.wasDisconnected()
this.response = response
}
}
}

View File

@ -384,7 +384,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren
new WriteTxnMarkersResponse(pidErrorMap)
}
synchronized {
txnMarkerChannelManager.generateRequests().foreach { requestAndHandler =>
txnMarkerChannelManager.generateRequests().asScala.foreach { requestAndHandler =>
val request = requestAndHandler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()
val response = createResponse(request)
requestAndHandler.handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1),

View File

@ -20,8 +20,6 @@ import java.util
import java.util.Arrays.asList
import java.util.Collections
import java.util.concurrent.{Callable, Executors, Future}
import kafka.common.RequestAndCompletionHandler
import kafka.server.{KafkaConfig, MetadataCache}
import kafka.utils.TestUtils
import org.apache.kafka.clients.{ClientResponse, NetworkClient}
@ -31,6 +29,7 @@ import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, Write
import org.apache.kafka.common.utils.MockTime
import org.apache.kafka.common.{Node, TopicPartition}
import org.apache.kafka.server.metrics.KafkaYammerMetrics
import org.apache.kafka.server.util.RequestAndCompletionHandler
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.Test
import org.mockito.ArgumentMatchers.any
@ -129,7 +128,7 @@ class TransactionMarkerChannelManagerTest {
response)
TestUtils.waitUntilTrue(() => {
val requests = channelManager.generateRequests()
val requests = channelManager.generateRequests().asScala
if (requests.nonEmpty) {
assertEquals(1, requests.size)
val request = requests.head
@ -194,7 +193,7 @@ class TransactionMarkerChannelManagerTest {
val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(),
asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition2)))).build()
val requests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler =>
val requests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().asScala.map { handler =>
(handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build())
}.toMap
@ -264,13 +263,13 @@ class TransactionMarkerChannelManagerTest {
val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(),
asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition2)))).build()
val firstDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler =>
val firstDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().asScala.map { handler =>
(handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build())
}.toMap
assertEquals(Map(broker2 -> expectedBroker2Request), firstDrainedRequests)
val secondDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler =>
val secondDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().asScala.map { handler =>
(handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build())
}.toMap
@ -345,7 +344,7 @@ class TransactionMarkerChannelManagerTest {
channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2)
val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests()
val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests().asScala
val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE))
for (requestAndHandler <- requestAndHandlers) {
@ -398,7 +397,7 @@ class TransactionMarkerChannelManagerTest {
channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2)
val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests()
val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests().asScala
val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE))
for (requestAndHandler <- requestAndHandlers) {
@ -452,7 +451,7 @@ class TransactionMarkerChannelManagerTest {
channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2)
val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests()
val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests().asScala
val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE))
for (requestAndHandler <- requestAndHandlers) {

View File

@ -17,7 +17,6 @@
package unit.kafka.server
import kafka.common.RequestAndCompletionHandler
import kafka.server.{AddPartitionsToTxnManager, KafkaConfig}
import kafka.utils.TestUtils
import org.apache.kafka.clients.{ClientResponse, NetworkClient}
@ -29,6 +28,7 @@ import org.apache.kafka.common.{Node, TopicPartition}
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.requests.{AbstractResponse, AddPartitionsToTxnRequest, AddPartitionsToTxnResponse}
import org.apache.kafka.common.utils.MockTime
import org.apache.kafka.server.util.RequestAndCompletionHandler
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
import org.mockito.Mockito.mock
@ -110,7 +110,7 @@ class AddPartitionsToTxnManagerTest {
addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId1, producerId1, producerEpoch = 0), setErrors(transaction1RetryWithOldEpochErrors))
assertEquals(expectedEpochErrors, transaction1RetryWithOldEpochErrors)
val requestsAndHandlers = addPartitionsToTxnManager.generateRequests()
val requestsAndHandlers = addPartitionsToTxnManager.generateRequests().asScala
requestsAndHandlers.foreach { requestAndHandler =>
if (requestAndHandler.destination == node0) {
assertEquals(time.milliseconds(), requestAndHandler.creationTimeMs)
@ -130,7 +130,7 @@ class AddPartitionsToTxnManagerTest {
addPartitionsToTxnManager.addTxnData(node0, transactionData(transactionalId1, producerId1), setErrors(transactionErrors))
addPartitionsToTxnManager.addTxnData(node1, transactionData(transactionalId2, producerId2), setErrors(transactionErrors))
val requestsAndHandlers = addPartitionsToTxnManager.generateRequests()
val requestsAndHandlers = addPartitionsToTxnManager.generateRequests().asScala
assertEquals(2, requestsAndHandlers.size)
// Note: handlers are tested in testAddPartitionsToTxnHandlerErrorHandling
requestsAndHandlers.foreach{ requestAndHandler =>
@ -147,7 +147,7 @@ class AddPartitionsToTxnManagerTest {
// Test creationTimeMs increases too.
time.sleep(1000)
val requestsAndHandlers2 = addPartitionsToTxnManager.generateRequests()
val requestsAndHandlers2 = addPartitionsToTxnManager.generateRequests().asScala
// The request for node1 should not be added because one request is already inflight.
assertEquals(1, requestsAndHandlers2.size)
requestsAndHandlers2.foreach { requestAndHandler =>
@ -156,7 +156,7 @@ class AddPartitionsToTxnManagerTest {
// Complete the request for node1 so the new one can go through.
requestsAndHandlers.filter(_.destination == node1).head.handler.onComplete(authenticationErrorResponse)
val requestsAndHandlers3 = addPartitionsToTxnManager.generateRequests()
val requestsAndHandlers3 = addPartitionsToTxnManager.generateRequests().asScala
assertEquals(1, requestsAndHandlers3.size)
requestsAndHandlers3.foreach { requestAndHandler =>
verifyRequest(node1, transactionalId2, producerId2, requestAndHandler)
@ -238,7 +238,7 @@ class AddPartitionsToTxnManagerTest {
}
private def receiveResponse(response: ClientResponse): Unit = {
addPartitionsToTxnManager.generateRequests().head.handler.onComplete(response)
addPartitionsToTxnManager.generateRequests().asScala.head.handler.onComplete(response)
}
private def verifyRequest(expectedDestination: Node, transactionalId: String, producerId: Long, requestAndHandler: RequestAndCompletionHandler): Unit = {

View File

@ -0,0 +1,253 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.server.util;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import org.apache.kafka.clients.ClientRequest;
import org.apache.kafka.clients.ClientResponse;
import org.apache.kafka.clients.KafkaClient;
import org.apache.kafka.clients.RequestCompletionHandler;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.errors.DisconnectException;
import org.apache.kafka.common.internals.FatalExitError;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils;
/**
* An inter-broker send thread that utilizes a non-blocking network client.
*/
public abstract class InterBrokerSendThread extends ShutdownableThread {
protected volatile KafkaClient networkClient;
private final int requestTimeoutMs;
private final Time time;
private final UnsentRequests unsentRequests;
public InterBrokerSendThread(
String name,
KafkaClient networkClient,
int requestTimeoutMs,
Time time
) {
this(name, networkClient, requestTimeoutMs, time, true);
}
public InterBrokerSendThread(
String name,
KafkaClient networkClient,
int requestTimeoutMs,
Time time,
boolean isInterruptible
) {
super(name, isInterruptible);
this.networkClient = networkClient;
this.requestTimeoutMs = requestTimeoutMs;
this.time = time;
this.unsentRequests = new UnsentRequests();
}
public abstract Collection<RequestAndCompletionHandler> generateRequests();
public boolean hasUnsentRequests() {
return unsentRequests.iterator().hasNext();
}
@Override
public void shutdown() throws InterruptedException {
initiateShutdown();
networkClient.initiateClose();
awaitShutdown();
Utils.closeQuietly(networkClient, "InterBrokerSendThread network client");
}
private void drainGeneratedRequests() {
generateRequests().forEach(request ->
unsentRequests.put(
request.destination,
networkClient.newClientRequest(
request.destination.idString(),
request.request,
request.creationTimeMs,
true,
requestTimeoutMs,
request.handler
)
)
);
}
protected void pollOnce(long maxTimeoutMs) {
try {
drainGeneratedRequests();
long now = time.milliseconds();
final long timeout = sendRequests(now, maxTimeoutMs);
networkClient.poll(timeout, now);
now = time.milliseconds();
checkDisconnects(now);
failExpiredRequests(now);
unsentRequests.clean();
} catch (FatalExitError fee) {
throw fee;
} catch (Throwable t) {
if (t instanceof DisconnectException && !networkClient.active()) {
// DisconnectException is expected when NetworkClient#initiateClose is called
return;
}
log.error("unhandled exception caught in InterBrokerSendThread", t);
// rethrow any unhandled exceptions as FatalExitError so the JVM will be terminated
// as we will be in an unknown state with potentially some requests dropped and not
// being able to make progress. Known and expected Errors should have been appropriately
// dealt with already.
throw new FatalExitError();
}
}
@Override
public void doWork() {
pollOnce(Long.MAX_VALUE);
}
private long sendRequests(long now, long maxTimeoutMs) {
long pollTimeout = maxTimeoutMs;
for (Node node : unsentRequests.nodes()) {
final Iterator<ClientRequest> requestIterator = unsentRequests.requestIterator(node);
while (requestIterator.hasNext()) {
final ClientRequest request = requestIterator.next();
if (networkClient.ready(node, now)) {
networkClient.send(request, now);
requestIterator.remove();
} else {
pollTimeout = Math.min(pollTimeout, networkClient.connectionDelay(node, now));
}
}
}
return pollTimeout;
}
private void checkDisconnects(long now) {
// any disconnects affecting requests that have already been transmitted will be handled
// by NetworkClient, so we just need to check whether connections for any of the unsent
// requests have been disconnected; if they have, then we complete the corresponding future
// and set the disconnect flag in the ClientResponse
final Iterator<Map.Entry<Node, ArrayDeque<ClientRequest>>> iterator = unsentRequests.iterator();
while (iterator.hasNext()) {
final Map.Entry<Node, ArrayDeque<ClientRequest>> entry = iterator.next();
final Node node = entry.getKey();
final ArrayDeque<ClientRequest> requests = entry.getValue();
if (!requests.isEmpty() && networkClient.connectionFailed(node)) {
iterator.remove();
for (ClientRequest request : requests) {
final AuthenticationException authenticationException = networkClient.authenticationException(node);
if (authenticationException != null) {
log.error("Failed to send the following request due to authentication error: {}", request);
}
completeWithDisconnect(request, now, authenticationException);
}
}
}
}
private void failExpiredRequests(long now) {
// clear all expired unsent requests
final Collection<ClientRequest> timedOutRequests = unsentRequests.removeAllTimedOut(now);
for (ClientRequest request : timedOutRequests) {
log.debug("Failed to send the following request after {} ms: {}", request.requestTimeoutMs(), request);
completeWithDisconnect(request, now, null);
}
}
private static void completeWithDisconnect(
ClientRequest request,
long now,
AuthenticationException authenticationException
) {
final RequestCompletionHandler handler = request.callback();
handler.onComplete(
new ClientResponse(
request.makeHeader(request.requestBuilder().latestAllowedVersion()),
handler,
request.destination(),
now /* createdTimeMs */,
now /* receivedTimeMs */,
true /* disconnected */,
null /* versionMismatch */,
authenticationException,
null
)
);
}
public void wakeup() {
networkClient.wakeup();
}
private static final class UnsentRequests {
private final Map<Node, ArrayDeque<ClientRequest>> unsent = new HashMap<>();
void put(Node node, ClientRequest request) {
ArrayDeque<ClientRequest> requests = unsent.computeIfAbsent(node, n -> new ArrayDeque<>());
requests.add(request);
}
Collection<ClientRequest> removeAllTimedOut(long now) {
final List<ClientRequest> expiredRequests = new ArrayList<>();
for (ArrayDeque<ClientRequest> requests : unsent.values()) {
final Iterator<ClientRequest> requestIterator = requests.iterator();
boolean foundExpiredRequest = false;
while (requestIterator.hasNext() && !foundExpiredRequest) {
final ClientRequest request = requestIterator.next();
final long elapsedMs = Math.max(0, now - request.createdTimeMs());
if (elapsedMs > request.requestTimeoutMs()) {
expiredRequests.add(request);
requestIterator.remove();
foundExpiredRequest = true;
}
}
}
return expiredRequests;
}
void clean() {
unsent.values().removeIf(ArrayDeque::isEmpty);
}
Iterator<Entry<Node, ArrayDeque<ClientRequest>>> iterator() {
return unsent.entrySet().iterator();
}
Iterator<ClientRequest> requestIterator(Node node) {
ArrayDeque<ClientRequest> requests = unsent.get(node);
return (requests == null) ? Collections.emptyIterator() : requests.iterator();
}
Set<Node> nodes() {
return unsent.keySet();
}
}
}

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.server.util;
import org.apache.kafka.clients.RequestCompletionHandler;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.requests.AbstractRequest;
public final class RequestAndCompletionHandler {
public final long creationTimeMs;
public final Node destination;
public final AbstractRequest.Builder<? extends AbstractRequest> request;
public final RequestCompletionHandler handler;
public RequestAndCompletionHandler(
long creationTimeMs,
Node destination,
AbstractRequest.Builder<? extends AbstractRequest> request,
RequestCompletionHandler handler
) {
this.creationTimeMs = creationTimeMs;
this.destination = destination;
this.request = request;
this.handler = handler;
}
@Override
public String toString() {
return "RequestAndCompletionHandler(" +
"creationTimeMs=" + creationTimeMs +
", destination=" + destination +
", request=" + request +
", handler=" + handler +
')';
}
}

View File

@ -29,7 +29,7 @@ public abstract class ShutdownableThread extends Thread {
public final String logPrefix;
private final Logger log;
protected final Logger log;
private final boolean isInterruptible;

View File

@ -0,0 +1,326 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.server.util;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Collections;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import org.apache.kafka.clients.ClientRequest;
import org.apache.kafka.clients.ClientResponse;
import org.apache.kafka.clients.KafkaClient;
import org.apache.kafka.clients.RequestCompletionHandler;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.errors.DisconnectException;
import org.apache.kafka.common.internals.FatalExitError;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.requests.AbstractRequest;
import org.apache.kafka.common.utils.Time;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentMatchers;
public class InterBrokerSendThreadTest {
private final Time time = new MockTime();
private final KafkaClient networkClient = mock(KafkaClient.class);
private final StubCompletionHandler completionHandler = new StubCompletionHandler();
private final int requestTimeoutMs = 1000;
class TestInterBrokerSendThread extends InterBrokerSendThread {
private final Consumer<Throwable> exceptionCallback;
private final Queue<RequestAndCompletionHandler> queue = new ArrayDeque<>();
TestInterBrokerSendThread() {
this(
InterBrokerSendThreadTest.this.networkClient,
t -> {
throw (t instanceof RuntimeException)
? ((RuntimeException) t)
: new RuntimeException(t);
});
}
TestInterBrokerSendThread(KafkaClient networkClient, Consumer<Throwable> exceptionCallback) {
super("name", networkClient, requestTimeoutMs, time);
this.exceptionCallback = exceptionCallback;
}
void enqueue(RequestAndCompletionHandler request) {
queue.offer(request);
}
@Override
public Collection<RequestAndCompletionHandler> generateRequests() {
return queue.isEmpty() ? Collections.emptyList() : Collections.singletonList(queue.poll());
}
@Override
protected void pollOnce(long maxTimeoutMs) {
try {
super.pollOnce(maxTimeoutMs);
} catch (Throwable t) {
exceptionCallback.accept(t);
}
}
}
@Test
public void testShutdownThreadShouldNotCauseException() throws InterruptedException, IOException {
// InterBrokerSendThread#shutdown calls NetworkClient#initiateClose first so NetworkClient#poll
// can throw DisconnectException when thread is running
when(networkClient.poll(anyLong(), anyLong())).thenThrow(new DisconnectException());
when(networkClient.active()).thenReturn(false);
AtomicReference<Throwable> exception = new AtomicReference<>();
final InterBrokerSendThread thread =
new TestInterBrokerSendThread(networkClient, exception::getAndSet);
thread.shutdown();
thread.pollOnce(100);
verify(networkClient).poll(anyLong(), anyLong());
verify(networkClient).initiateClose();
verify(networkClient).close();
verify(networkClient).active();
verifyNoMoreInteractions(networkClient);
assertNull(exception.get());
}
@Test
public void testDisconnectWithoutShutdownShouldCauseException() {
DisconnectException de = new DisconnectException();
when(networkClient.poll(anyLong(), anyLong())).thenThrow(de);
when(networkClient.active()).thenReturn(true);
AtomicReference<Throwable> throwable = new AtomicReference<>();
final InterBrokerSendThread thread =
new TestInterBrokerSendThread(networkClient, throwable::getAndSet);
thread.pollOnce(100);
verify(networkClient).poll(anyLong(), anyLong());
verify(networkClient).active();
verifyNoMoreInteractions(networkClient);
Throwable thrown = throwable.get();
assertNotNull(thrown);
assertTrue(thrown instanceof FatalExitError);
}
@Test
public void testShouldNotSendAnythingWhenNoRequests() {
final InterBrokerSendThread sendThread = new TestInterBrokerSendThread();
// poll is always called but there should be no further invocations on NetworkClient
when(networkClient.poll(anyLong(), anyLong())).thenReturn(Collections.emptyList());
sendThread.doWork();
verify(networkClient).poll(anyLong(), anyLong());
verifyNoMoreInteractions(networkClient);
assertFalse(completionHandler.executedWithDisconnectedResponse);
}
@Test
public void testShouldCreateClientRequestAndSendWhenNodeIsReady() {
final AbstractRequest.Builder<?> request = new StubRequestBuilder<>();
final Node node = new Node(1, "", 8080);
final RequestAndCompletionHandler handler =
new RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler);
final TestInterBrokerSendThread sendThread = new TestInterBrokerSendThread();
final ClientRequest clientRequest =
new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler);
when(networkClient.newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
anyLong(),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler)
)).thenReturn(clientRequest);
when(networkClient.ready(node, time.milliseconds())).thenReturn(true);
when(networkClient.poll(anyLong(), anyLong())).thenReturn(Collections.emptyList());
sendThread.enqueue(handler);
sendThread.doWork();
verify(networkClient)
.newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
anyLong(),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler));
verify(networkClient).ready(any(), anyLong());
verify(networkClient).send(same(clientRequest), anyLong());
verify(networkClient).poll(anyLong(), anyLong());
verifyNoMoreInteractions(networkClient);
assertFalse(completionHandler.executedWithDisconnectedResponse);
}
@Test
public void testShouldCallCompletionHandlerWithDisconnectedResponseWhenNodeNotReady() {
final AbstractRequest.Builder<?> request = new StubRequestBuilder<>();
final Node node = new Node(1, "", 8080);
final RequestAndCompletionHandler handler =
new RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler);
final TestInterBrokerSendThread sendThread = new TestInterBrokerSendThread();
final ClientRequest clientRequest =
new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler);
when(networkClient.newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
anyLong(),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler)
)).thenReturn(clientRequest);
when(networkClient.ready(node, time.milliseconds())).thenReturn(false);
when(networkClient.connectionDelay(any(), anyLong())).thenReturn(0L);
when(networkClient.poll(anyLong(), anyLong())).thenReturn(Collections.emptyList());
when(networkClient.connectionFailed(node)).thenReturn(true);
when(networkClient.authenticationException(node)).thenReturn(new AuthenticationException(""));
sendThread.enqueue(handler);
sendThread.doWork();
verify(networkClient)
.newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
anyLong(),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler));
verify(networkClient).ready(any(), anyLong());
verify(networkClient).connectionDelay(any(), anyLong());
verify(networkClient).poll(anyLong(), anyLong());
verify(networkClient).connectionFailed(any());
verify(networkClient).authenticationException(any());
verifyNoMoreInteractions(networkClient);
assertTrue(completionHandler.executedWithDisconnectedResponse);
}
@Test
public void testFailingExpiredRequests() {
final AbstractRequest.Builder<?> request = new StubRequestBuilder<>();
final Node node = new Node(1, "", 8080);
final RequestAndCompletionHandler handler =
new RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler);
final TestInterBrokerSendThread sendThread = new TestInterBrokerSendThread();
final ClientRequest clientRequest =
new ClientRequest(
"dest", request, 0, "1", time.milliseconds(), true, requestTimeoutMs, handler.handler);
time.sleep(1500L);
when(networkClient.newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
ArgumentMatchers.eq(handler.creationTimeMs),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler)
)).thenReturn(clientRequest);
// make the node unready so the request is not cleared
when(networkClient.ready(node, time.milliseconds())).thenReturn(false);
when(networkClient.connectionDelay(any(), anyLong())).thenReturn(0L);
when(networkClient.poll(anyLong(), anyLong())).thenReturn(Collections.emptyList());
// rule out disconnects so the request stays for the expiry check
when(networkClient.connectionFailed(node)).thenReturn(false);
sendThread.enqueue(handler);
sendThread.doWork();
verify(networkClient)
.newClientRequest(
ArgumentMatchers.eq("1"),
same(handler.request),
ArgumentMatchers.eq(handler.creationTimeMs),
ArgumentMatchers.eq(true),
ArgumentMatchers.eq(requestTimeoutMs),
same(handler.handler));
verify(networkClient).ready(any(), anyLong());
verify(networkClient).connectionDelay(any(), anyLong());
verify(networkClient).poll(anyLong(), anyLong());
verify(networkClient).connectionFailed(any());
verifyNoMoreInteractions(networkClient);
assertFalse(sendThread.hasUnsentRequests());
assertTrue(completionHandler.executedWithDisconnectedResponse);
}
private static class StubRequestBuilder<T extends AbstractRequest>
extends AbstractRequest.Builder<T> {
private StubRequestBuilder() {
super(ApiKeys.END_TXN);
}
@Override
public T build(short version) {
return null;
}
}
private static class StubCompletionHandler implements RequestCompletionHandler {
public boolean executedWithDisconnectedResponse = false;
ClientResponse response = null;
@Override
public void onComplete(ClientResponse response) {
this.executedWithDisconnectedResponse = response.wasDisconnected();
this.response = response;
}
}
}