mirror of https://github.com/apache/kafka.git
KAFKA-15653: Pass requestLocal as argument to callback so we use the correct one for the thread (#14629)
With the new callback mechanism we were accidentally passing a context with the wrong RequestLocal. Now include a RequestLocal as an explicit argument to the callback, and make the arguments passed through the callback clearer by separating the method out. Added a test to ensure we use the request handler's RequestLocal, and not the one passed in, when the callback is executed via the request handler.

Reviewers: Ismael Juma <ismael@juma.me.uk>, Divij Vaidya <diviv@amazon.com>, David Jacot <djacot@confluent.io>, Jason Gustafson <jason@confluent.io>, Artem Livshits <alivshits@confluent.io>, Jun Rao <junrao@gmail.com>
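To make the problem concrete, here is a minimal, self-contained Scala sketch of the underlying pattern; it is illustrative only and not Kafka's API (`ThreadState`, `scheduleCapturing`, and `schedulePassing` are hypothetical stand-ins for RequestLocal and the request-handler scheduling). A callback that closes over the submitting thread's per-thread state keeps using that state even when it later runs on a different handler thread, while a callback that takes the state as an argument receives whatever the executing thread supplies.

// Hypothetical sketch only: ThreadState stands in for RequestLocal and the
// single-threaded executor stands in for a request handler thread.
import java.util.concurrent.{Executors, TimeUnit}

object RequestLocalSketch {
  final case class ThreadState(owner: String) // per-thread state, like RequestLocal

  def main(args: Array[String]): Unit = {
    val handlerPool = Executors.newFixedThreadPool(1)

    // Old shape: the callback captures the submitter's state, so it still sees
    // the submitting thread's ThreadState when it runs on the handler thread.
    def scheduleCapturing(callback: () => Unit): Unit =
      handlerPool.execute(() => callback())

    // New shape: the executing thread supplies its own state as an argument
    // at the moment the callback actually runs.
    def schedulePassing(callback: ThreadState => Unit): Unit =
      handlerPool.execute(() => callback(ThreadState(Thread.currentThread().getName)))

    val submitterState = ThreadState(Thread.currentThread().getName)
    scheduleCapturing(() => println(s"captured state owner: ${submitterState.owner}"))
    schedulePassing(state => println(s"passed-in state owner: ${state.owner}"))

    handlerPool.shutdown()
    handlerPool.awaitTermination(5, TimeUnit.SECONDS)
  }
}

The second shape is the one this commit moves to: the request handler thread that executes the callback supplies its own RequestLocal, so per-thread state is never borrowed from the thread that scheduled the callback.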
This commit is contained in:
parent edc7e10a74
commit 91fa196930
@@ -24,7 +24,7 @@ import com.fasterxml.jackson.databind.JsonNode
 import com.typesafe.scalalogging.Logger
 import com.yammer.metrics.core.Meter
 import kafka.network
-import kafka.server.KafkaConfig
+import kafka.server.{KafkaConfig, RequestLocal}
 import kafka.utils.{Logging, NotNothing, Pool}
 import kafka.utils.Implicits._
 import org.apache.kafka.common.config.ConfigResource

@@ -80,7 +80,7 @@ object RequestChannel extends Logging {
     }
   }

-  case class CallbackRequest(fun: () => Unit,
+  case class CallbackRequest(fun: RequestLocal => Unit,
                              originalRequest: Request) extends BaseRequest

   class Request(val processor: Int,

@@ -50,28 +50,34 @@ object KafkaRequestHandler {
   }

   /**
-   * Wrap callback to schedule it on a request thread.
-   * NOTE: this function must be called on a request thread.
-   * @param fun Callback function to execute
-   * @return Wrapped callback that would execute `fun` on a request thread
+   * Creates a wrapped callback to be executed synchronously on the calling request thread or asynchronously
+   * on an arbitrary request thread.
+   * NOTE: this function must be originally called from a request thread.
+   * @param asyncCompletionCallback A callback method that we intend to call from the current thread or in another
+   *                                thread after an asynchronous action completes. The RequestLocal passed in must
+   *                                belong to the request handler thread that is executing the callback.
+   * @param requestLocal The RequestLocal for the current request handler thread in case we need to execute the callback
+   *                     function synchronously from the calling thread.
+   * @return Wrapped callback will either immediately execute `asyncCompletionCallback` or schedule it on an arbitrary request thread
+   *         depending on where it is called
    */
-  def wrap[T](fun: T => Unit): T => Unit = {
+  def wrapAsyncCallback[T](asyncCompletionCallback: (RequestLocal, T) => Unit, requestLocal: RequestLocal): T => Unit = {
     val requestChannel = threadRequestChannel.get()
     val currentRequest = threadCurrentRequest.get()
     if (requestChannel == null || currentRequest == null) {
       if (!bypassThreadCheck)
         throw new IllegalStateException("Attempted to reschedule to request handler thread from non-request handler thread.")
-      T => fun(T)
+      T => asyncCompletionCallback(requestLocal, T)
     } else {
       T => {
-        if (threadCurrentRequest.get() != null) {
-          // If the callback is actually executed on a request thread, we can directly execute
+        if (threadCurrentRequest.get() == currentRequest) {
+          // If the callback is actually executed on the same request thread, we can directly execute
           // it without re-scheduling it.
-          fun(T)
+          asyncCompletionCallback(requestLocal, T)
         } else {
           // The requestChannel and request are captured in this lambda, so when it's executed on the callback thread
           // we can re-schedule the original callback on a request thread and update the metrics accordingly.
-          requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(() => fun(T), currentRequest))
+          requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(newRequestLocal => asyncCompletionCallback(newRequestLocal, T), currentRequest))
         }
       }
     }

@@ -132,7 +138,7 @@ class KafkaRequestHandler(
             }

             threadCurrentRequest.set(originalRequest)
-            callback.fun()
+            callback.fun(requestLocal)
           } catch {
             case e: FatalExitError =>
               completeShutdown()

@@ -174,6 +180,7 @@ class KafkaRequestHandler(

   private def completeShutdown(): Unit = {
     requestLocal.close()
+    threadRequestChannel.remove()
     shutdownComplete.countDown()
   }

@@ -777,7 +777,6 @@ class ReplicaManager(val config: KafkaConfig,
                     transactionalId: String = null,
                     actionQueue: ActionQueue = this.actionQueue): Unit = {
     if (isValidRequiredAcks(requiredAcks)) {
-      val sTime = time.milliseconds

       val verificationGuards: mutable.Map[TopicPartition, VerificationGuard] = mutable.Map[TopicPartition, VerificationGuard]()
       val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition, errorsPerPartition) =

@@ -791,96 +790,9 @@ class ReplicaManager(val config: KafkaConfig,
           (verifiedEntries.toMap, unverifiedEntries.toMap, errorEntries.toMap)
         }

-      def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = {
-        val verifiedEntries =
-          if (unverifiedEntries.isEmpty)
-            allEntries
-          else
-            allEntries.filter { case (tp, _) =>
-              !unverifiedEntries.contains(tp)
-            }
-
-        val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
-          origin, verifiedEntries, requiredAcks, requestLocal, verificationGuards.toMap)
-        debug("Produce to local log in %d ms".format(time.milliseconds - sTime))
-
-        val errorResults = (unverifiedEntries ++ errorsPerPartition).map {
-          case (topicPartition, error) =>
-            // translate transaction coordinator errors to known producer response errors
-            val customException =
-              error match {
-                case Errors.INVALID_TXN_STATE => Some(error.exception("Partition was not added to the transaction"))
-                case Errors.CONCURRENT_TRANSACTIONS |
-                     Errors.COORDINATOR_LOAD_IN_PROGRESS |
-                     Errors.COORDINATOR_NOT_AVAILABLE |
-                     Errors.NOT_COORDINATOR => Some(new NotEnoughReplicasException(
-                  s"Unable to verify the partition has been added to the transaction. Underlying error: ${error.toString}"))
-                case _ => None
-              }
-            topicPartition -> LogAppendResult(
-              LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
-              Some(customException.getOrElse(error.exception)),
-              hasCustomErrorMessage = customException.isDefined
-            )
-        }
-
-        val allResults = localProduceResults ++ errorResults
-        val produceStatus = allResults.map { case (topicPartition, result) =>
-          topicPartition -> ProducePartitionStatus(
-            result.info.lastOffset + 1, // required offset
-            new PartitionResponse(
-              result.error,
-              result.info.firstOffset,
-              result.info.lastOffset,
-              result.info.logAppendTime,
-              result.info.logStartOffset,
-              result.info.recordErrors,
-              result.errorMessage
-            )
-          ) // response status
-        }
-
-        actionQueue.add {
-          () => allResults.foreach { case (topicPartition, result) =>
-            val requestKey = TopicPartitionOperationKey(topicPartition)
-            result.info.leaderHwChange match {
-              case LeaderHwChange.INCREASED =>
-                // some delayed operations may be unblocked after HW changed
-                delayedProducePurgatory.checkAndComplete(requestKey)
-                delayedFetchPurgatory.checkAndComplete(requestKey)
-                delayedDeleteRecordsPurgatory.checkAndComplete(requestKey)
-              case LeaderHwChange.SAME =>
-                // probably unblock some follower fetch requests since log end offset has been updated
-                delayedFetchPurgatory.checkAndComplete(requestKey)
-              case LeaderHwChange.NONE =>
-                // nothing
-            }
-          }
-        }
-
-        recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats })
-
-        if (delayedProduceRequestRequired(requiredAcks, allEntries, allResults)) {
-          // create delayed produce operation
-          val produceMetadata = ProduceMetadata(requiredAcks, produceStatus)
-          val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock)
-
-          // create a list of (topic, partition) pairs to use as keys for this delayed produce operation
-          val producerRequestKeys = allEntries.keys.map(TopicPartitionOperationKey(_)).toSeq
-
-          // try to complete the request immediately, otherwise put it into the purgatory
-          // this is because while the delayed produce operation is being created, new
-          // requests may arrive and hence make this operation completable.
-          delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
-        } else {
-          // we can respond immediately
-          val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus }
-          responseCallback(produceResponseStatus)
-        }
-      }
-
       if (notYetVerifiedEntriesPerPartition.isEmpty || addPartitionsToTxnManager.isEmpty) {
-        appendEntries(verifiedEntriesPerPartition)(Map.empty)
+        appendEntries(verifiedEntriesPerPartition, internalTopicsAllowed, origin, requiredAcks, verificationGuards.toMap,
+          errorsPerPartition, recordConversionStatsCallback, timeout, responseCallback, delayedProduceLock)(requestLocal, Map.empty)
       } else {
         // For unverified entries, send a request to verify. When verified, the append process will proceed via the callback.
         // We verify above that all partitions use the same producer ID.

@@ -890,7 +802,20 @@ class ReplicaManager(val config: KafkaConfig,
           producerId = batchInfo.producerId,
           producerEpoch = batchInfo.producerEpoch,
           topicPartitions = notYetVerifiedEntriesPerPartition.keySet.toSeq,
-          callback = KafkaRequestHandler.wrap(appendEntries(entriesPerPartition)(_))
+          callback = KafkaRequestHandler.wrapAsyncCallback(
+            appendEntries(
+              entriesPerPartition,
+              internalTopicsAllowed,
+              origin,
+              requiredAcks,
+              verificationGuards.toMap,
+              errorsPerPartition,
+              recordConversionStatsCallback,
+              timeout,
+              responseCallback,
+              delayedProduceLock
+            ),
+            requestLocal)
         ))
       }
     } else {

@@ -908,6 +833,110 @@ class ReplicaManager(val config: KafkaConfig,
     }
   }

+  /*
+   * Note: This method can be used as a callback in a different request thread. Ensure that correct RequestLocal
+   * is passed when executing this method. Accessing non-thread-safe data structures should be avoided if possible.
+   */
+  private def appendEntries(allEntries: Map[TopicPartition, MemoryRecords],
+                            internalTopicsAllowed: Boolean,
+                            origin: AppendOrigin,
+                            requiredAcks: Short,
+                            verificationGuards: Map[TopicPartition, VerificationGuard],
+                            errorsPerPartition: Map[TopicPartition, Errors],
+                            recordConversionStatsCallback: Map[TopicPartition, RecordConversionStats] => Unit,
+                            timeout: Long,
+                            responseCallback: Map[TopicPartition, PartitionResponse] => Unit,
+                            delayedProduceLock: Option[Lock])
+                           (requestLocal: RequestLocal, unverifiedEntries: Map[TopicPartition, Errors]): Unit = {
+    val sTime = time.milliseconds
+    val verifiedEntries =
+      if (unverifiedEntries.isEmpty)
+        allEntries
+      else
+        allEntries.filter { case (tp, _) =>
+          !unverifiedEntries.contains(tp)
+        }
+
+    val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
+      origin, verifiedEntries, requiredAcks, requestLocal, verificationGuards.toMap)
+    debug("Produce to local log in %d ms".format(time.milliseconds - sTime))
+
+    val errorResults = (unverifiedEntries ++ errorsPerPartition).map {
+      case (topicPartition, error) =>
+        // translate transaction coordinator errors to known producer response errors
+        val customException =
+          error match {
+            case Errors.INVALID_TXN_STATE => Some(error.exception("Partition was not added to the transaction"))
+            case Errors.CONCURRENT_TRANSACTIONS |
+                 Errors.COORDINATOR_LOAD_IN_PROGRESS |
+                 Errors.COORDINATOR_NOT_AVAILABLE |
+                 Errors.NOT_COORDINATOR => Some(new NotEnoughReplicasException(
+              s"Unable to verify the partition has been added to the transaction. Underlying error: ${error.toString}"))
+            case _ => None
+          }
+        topicPartition -> LogAppendResult(
+          LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
+          Some(customException.getOrElse(error.exception)),
+          hasCustomErrorMessage = customException.isDefined
+        )
+    }
+
+    val allResults = localProduceResults ++ errorResults
+    val produceStatus = allResults.map { case (topicPartition, result) =>
+      topicPartition -> ProducePartitionStatus(
+        result.info.lastOffset + 1, // required offset
+        new PartitionResponse(
+          result.error,
+          result.info.firstOffset,
+          result.info.lastOffset,
+          result.info.logAppendTime,
+          result.info.logStartOffset,
+          result.info.recordErrors,
+          result.errorMessage
+        )
+      ) // response status
+    }
+
+    actionQueue.add {
+      () =>
+        allResults.foreach { case (topicPartition, result) =>
+          val requestKey = TopicPartitionOperationKey(topicPartition)
+          result.info.leaderHwChange match {
+            case LeaderHwChange.INCREASED =>
+              // some delayed operations may be unblocked after HW changed
+              delayedProducePurgatory.checkAndComplete(requestKey)
+              delayedFetchPurgatory.checkAndComplete(requestKey)
+              delayedDeleteRecordsPurgatory.checkAndComplete(requestKey)
+            case LeaderHwChange.SAME =>
+              // probably unblock some follower fetch requests since log end offset has been updated
+              delayedFetchPurgatory.checkAndComplete(requestKey)
+            case LeaderHwChange.NONE =>
+              // nothing
+          }
+        }
+    }
+
+    recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats })
+
+    if (delayedProduceRequestRequired(requiredAcks, allEntries, allResults)) {
+      // create delayed produce operation
+      val produceMetadata = ProduceMetadata(requiredAcks, produceStatus)
+      val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock)
+
+      // create a list of (topic, partition) pairs to use as keys for this delayed produce operation
+      val producerRequestKeys = allEntries.keys.map(TopicPartitionOperationKey(_)).toSeq
+
+      // try to complete the request immediately, otherwise put it into the purgatory
+      // this is because while the delayed produce operation is being created, new
+      // requests may arrive and hence make this operation completable.
+      delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
+    } else {
+      // we can respond immediately
+      val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus }
+      responseCallback(produceResponseStatus)
+    }
+  }
+
   private def partitionEntriesForVerification(verificationGuards: mutable.Map[TopicPartition, VerificationGuard],
                                               entriesPerPartition: Map[TopicPartition, MemoryRecords],
                                               verifiedEntries: mutable.Map[TopicPartition, MemoryRecords],

@@ -24,7 +24,7 @@ import org.apache.kafka.common.network.{ClientInformation, ListenerName}
 import org.apache.kafka.common.protocol.ApiKeys
 import org.apache.kafka.common.requests.{RequestContext, RequestHeader}
 import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
-import org.apache.kafka.common.utils.{MockTime, Time}
+import org.apache.kafka.common.utils.{BufferSupplier, MockTime, Time}
 import org.apache.kafka.server.log.remote.storage.{RemoteLogManagerConfig, RemoteStorageMetrics}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
 import org.junit.jupiter.api.Test

@@ -32,7 +32,7 @@ import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
 import org.mockito.ArgumentMatchers
 import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.{mock, when}
+import org.mockito.Mockito.{mock, times, verify, when}

 import java.net.InetAddress
 import java.nio.ByteBuffer

@@ -57,10 +57,12 @@ class KafkaRequestHandlerTest {
     when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
       time.sleep(2)
       // Prepare the callback.
-      val callback = KafkaRequestHandler.wrap((ms: Int) => {
-        time.sleep(ms)
-        handler.stop()
-      })
+      val callback = KafkaRequestHandler.wrapAsyncCallback(
+        (reqLocal: RequestLocal, ms: Int) => {
+          time.sleep(ms)
+          handler.stop()
+        },
+        RequestLocal.NoCaching)
       // Execute the callback asynchronously.
       CompletableFuture.runAsync(() => callback(1))
       request.apiLocalCompleteTimeNanos = time.nanoseconds

@@ -94,9 +96,11 @@ class KafkaRequestHandlerTest {
     when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
       handledCount = handledCount + 1
       // Prepare the callback.
-      val callback = KafkaRequestHandler.wrap((ms: Int) => {
-        handler.stop()
-      })
+      val callback = KafkaRequestHandler.wrapAsyncCallback(
+        (reqLocal: RequestLocal, ms: Int) => {
+          handler.stop()
+        },
+        RequestLocal.NoCaching)
       // Execute the callback asynchronously.
       CompletableFuture.runAsync(() => callback(1))
     }

@@ -111,6 +115,75 @@ class KafkaRequestHandlerTest {
     assertEquals(1, tryCompleteActionCount)
   }

+  @Test
+  def testHandlingCallbackOnNewThread(): Unit = {
+    val time = new MockTime()
+    val metrics = mock(classOf[RequestChannel.Metrics])
+    val apiHandler = mock(classOf[ApiRequestHandler])
+    val requestChannel = new RequestChannel(10, "", time, metrics)
+    val handler = new KafkaRequestHandler(0, 0, mock(classOf[Meter]), new AtomicInteger(1), requestChannel, apiHandler, time)
+
+    val originalRequestLocal = mock(classOf[RequestLocal])
+
+    var handledCount = 0
+
+    val request = makeRequest(time, metrics)
+    requestChannel.sendRequest(request)
+
+    when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
+      // Prepare the callback.
+      val callback = KafkaRequestHandler.wrapAsyncCallback(
+        (reqLocal: RequestLocal, ms: Int) => {
+          reqLocal.bufferSupplier.close()
+          handledCount = handledCount + 1
+          handler.stop()
+        },
+        originalRequestLocal)
+      // Execute the callback asynchronously.
+      CompletableFuture.runAsync(() => callback(1))
+    }
+
+    handler.run()
+    // Verify that we don't use the request local that we passed in.
+    verify(originalRequestLocal, times(0)).bufferSupplier
+    assertEquals(1, handledCount)
+  }
+
+  @Test
+  def testCallbackOnSameThread(): Unit = {
+    val time = new MockTime()
+    val metrics = mock(classOf[RequestChannel.Metrics])
+    val apiHandler = mock(classOf[ApiRequestHandler])
+    val requestChannel = new RequestChannel(10, "", time, metrics)
+    val handler = new KafkaRequestHandler(0, 0, mock(classOf[Meter]), new AtomicInteger(1), requestChannel, apiHandler, time)
+
+    val originalRequestLocal = mock(classOf[RequestLocal])
+    when(originalRequestLocal.bufferSupplier).thenReturn(BufferSupplier.create())
+
+    var handledCount = 0
+
+    val request = makeRequest(time, metrics)
+    requestChannel.sendRequest(request)
+
+    when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
+      // Prepare the callback.
+      val callback = KafkaRequestHandler.wrapAsyncCallback(
+        (reqLocal: RequestLocal, ms: Int) => {
+          reqLocal.bufferSupplier.close()
+          handledCount = handledCount + 1
+          handler.stop()
+        },
+        originalRequestLocal)
+      // Execute the callback before the request returns.
+      callback(1)
+    }
+
+    handler.run()
+    // Verify that we do use the request local that we passed in.
+    verify(originalRequestLocal, times(1)).bufferSupplier
+    assertEquals(1, handledCount)
+  }
+

   @ParameterizedTest
   @ValueSource(booleans = Array(true, false))