KAFKA-15653: Pass requestLocal as argument to callback so we use the correct one for the thread (#14629)

With the new callback mechanism, we were accidentally passing a context with the wrong RequestLocal. Now we pass a RequestLocal as an explicit argument to the callback.

Also make the arguments passed through the callback clearer by separating the method out.

Added a test to ensure that, when the callback is executed via the request handler, we use the request handler's RequestLocal rather than the one passed in.
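
The gist of the bug, as a standalone sketch (illustrative Scala, not Kafka code): a closure that captures the caller's per-thread state keeps using it even when another thread runs the closure, whereas taking that state as a parameter lets the executing thread supply its own.

object RequestLocalSketch {
  // Stand-in for Kafka's per-request-handler-thread RequestLocal.
  final case class ThreadState(ownerThread: String)

  // Broken shape: captures the caller's state. If another request thread runs
  // the callback later, it touches state owned by a different thread.
  def wrapCapturing(callerState: ThreadState)(fun: ThreadState => Unit): () => Unit =
    () => fun(callerState)

  // Fixed shape: whichever thread executes the callback passes its own state.
  def wrapParameterized(fun: ThreadState => Unit): ThreadState => Unit =
    executingState => fun(executingState)

  def main(args: Array[String]): Unit = {
    val callerState = ThreadState("request-handler-0")
    val executorState = ThreadState("request-handler-1")

    val broken = wrapCapturing(callerState)(s => println(s"using state of ${s.ownerThread}"))
    broken() // still "request-handler-0", even if handler 1 runs this

    val fixed = wrapParameterized(s => println(s"using state of ${s.ownerThread}"))
    fixed(executorState) // "request-handler-1": the executing thread's own state
  }
}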

Reviewers: Ismael Juma <ismael@juma.me.uk>, Divij Vaidya <diviv@amazon.com>, David Jacot <djacot@confluent.io>, Jason Gustafson <jason@confluent.io>, Artem Livshits <alivshits@confluent.io>, Jun Rao <junrao@gmail.com>
Justine Olshan · 2023-11-07 15:14:17 -08:00 · committed by GitHub
parent edc7e10a74 · commit 91fa196930
4 changed files with 222 additions and 113 deletions

core/src/main/scala/kafka/network/RequestChannel.scala

@@ -24,7 +24,7 @@ import com.fasterxml.jackson.databind.JsonNode
 import com.typesafe.scalalogging.Logger
 import com.yammer.metrics.core.Meter
 import kafka.network
-import kafka.server.KafkaConfig
+import kafka.server.{KafkaConfig, RequestLocal}
 import kafka.utils.{Logging, NotNothing, Pool}
 import kafka.utils.Implicits._
 import org.apache.kafka.common.config.ConfigResource
@@ -80,7 +80,7 @@ object RequestChannel extends Logging {
     }
   }
 
-  case class CallbackRequest(fun: () => Unit,
+  case class CallbackRequest(fun: RequestLocal => Unit,
                              originalRequest: Request) extends BaseRequest
 
   class Request(val processor: Int,
core/src/main/scala/kafka/server/KafkaRequestHandler.scala

@@ -50,28 +50,34 @@ object KafkaRequestHandler {
   }
 
   /**
-   * Wrap callback to schedule it on a request thread.
-   * NOTE: this function must be called on a request thread.
-   * @param fun Callback function to execute
-   * @return Wrapped callback that would execute `fun` on a request thread
+   * Creates a wrapped callback to be executed synchronously on the calling request thread or asynchronously
+   * on an arbitrary request thread.
+   * NOTE: this function must be originally called from a request thread.
+   * @param asyncCompletionCallback A callback method that we intend to call from the current thread or in another
+   *                                thread after an asynchronous action completes. The RequestLocal passed in must
+   *                                belong to the request handler thread that is executing the callback.
+   * @param requestLocal The RequestLocal for the current request handler thread in case we need to execute the callback
+   *                     function synchronously from the calling thread.
+   * @return Wrapped callback will either immediately execute `asyncCompletionCallback` or schedule it on an arbitrary request thread
+   *         depending on where it is called
    */
-  def wrap[T](fun: T => Unit): T => Unit = {
+  def wrapAsyncCallback[T](asyncCompletionCallback: (RequestLocal, T) => Unit, requestLocal: RequestLocal): T => Unit = {
     val requestChannel = threadRequestChannel.get()
     val currentRequest = threadCurrentRequest.get()
     if (requestChannel == null || currentRequest == null) {
       if (!bypassThreadCheck)
         throw new IllegalStateException("Attempted to reschedule to request handler thread from non-request handler thread.")
-      T => fun(T)
+      T => asyncCompletionCallback(requestLocal, T)
     } else {
       T => {
-        if (threadCurrentRequest.get() != null) {
-          // If the callback is actually executed on a request thread, we can directly execute
+        if (threadCurrentRequest.get() == currentRequest) {
+          // If the callback is actually executed on the same request thread, we can directly execute
           // it without re-scheduling it.
-          fun(T)
+          asyncCompletionCallback(requestLocal, T)
         } else {
           // The requestChannel and request are captured in this lambda, so when it's executed on the callback thread
           // we can re-schedule the original callback on a request thread and update the metrics accordingly.
-          requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(() => fun(T), currentRequest))
+          requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(newRequestLocal => asyncCompletionCallback(newRequestLocal, T), currentRequest))
        }
       }
     }
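
For context, a hedged sketch of how a caller on a request handler thread would use the new signature. doAppend and verifyTransactionAsync are hypothetical stand-ins, not Kafka APIs; only wrapAsyncCallback and RequestLocal are real.

import kafka.server.{KafkaRequestHandler, RequestLocal}

object WrapUsageSketch {
  def doAppend(result: Int, requestLocal: RequestLocal): Unit = ()

  def verifyTransactionAsync(onComplete: Int => Unit): Unit = onComplete(42)

  def handleWrite(requestLocal: RequestLocal): Unit = {
    // Must run on a request handler thread so the thread-locals are populated.
    val callback: Int => Unit = KafkaRequestHandler.wrapAsyncCallback(
      (executingRequestLocal: RequestLocal, result: Int) => {
        // Use the RequestLocal supplied here: the caller's when the callback
        // fires synchronously, the executing handler thread's when it was
        // rescheduled through the request channel.
        doAppend(result, executingRequestLocal)
      },
      requestLocal) // used only on the synchronous, same-thread path
    verifyTransactionAsync(onComplete = callback)
  }
}
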
@@ -132,7 +138,7 @@ class KafkaRequestHandler(
             }
             threadCurrentRequest.set(originalRequest)
-            callback.fun()
+            callback.fun(requestLocal)
           } catch {
             case e: FatalExitError =>
               completeShutdown()
@@ -174,6 +180,7 @@ class KafkaRequestHandler(
   private def completeShutdown(): Unit = {
     requestLocal.close()
+    threadRequestChannel.remove()
     shutdownComplete.countDown()
   }
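
The run-loop change above is the heart of the fix: when a handler thread picks a CallbackRequest off the channel, it now passes its own requestLocal into the callback. A minimal, self-contained sketch of that dispatch pattern (plain Scala threads and a queue, not Kafka code):

import java.util.concurrent.LinkedBlockingQueue

object WorkerSketch {
  // Each worker owns its state and injects it into every callback it executes,
  // analogous to callback.fun(requestLocal) in the handler's run loop.
  final class Worker(name: String, queue: LinkedBlockingQueue[String => Unit]) extends Thread(name) {
    private val myState = s"state-of-$name"
    override def run(): Unit = {
      val callback = queue.take()
      callback(myState) // the executing thread supplies its own state
    }
  }

  def main(args: Array[String]): Unit = {
    val queue = new LinkedBlockingQueue[String => Unit]()
    val worker = new Worker("request-handler-1", queue)
    worker.start()
    queue.put(state => println(s"callback ran with $state")) // prints state-of-request-handler-1
    worker.join()
  }
}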

core/src/main/scala/kafka/server/ReplicaManager.scala

@@ -777,7 +777,6 @@ class ReplicaManager(val config: KafkaConfig,
                     transactionalId: String = null,
                     actionQueue: ActionQueue = this.actionQueue): Unit = {
     if (isValidRequiredAcks(requiredAcks)) {
-      val sTime = time.milliseconds
 
       val verificationGuards: mutable.Map[TopicPartition, VerificationGuard] = mutable.Map[TopicPartition, VerificationGuard]()
       val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition, errorsPerPartition) =
@@ -791,96 +790,9 @@ class ReplicaManager(val config: KafkaConfig,
           (verifiedEntries.toMap, unverifiedEntries.toMap, errorEntries.toMap)
         }
 
-      def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = {
-        val verifiedEntries =
-          if (unverifiedEntries.isEmpty)
-            allEntries
-          else
-            allEntries.filter { case (tp, _) =>
-              !unverifiedEntries.contains(tp)
-            }
-
-        val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
-          origin, verifiedEntries, requiredAcks, requestLocal, verificationGuards.toMap)
-        debug("Produce to local log in %d ms".format(time.milliseconds - sTime))
-
-        val errorResults = (unverifiedEntries ++ errorsPerPartition).map {
-          case (topicPartition, error) =>
-            // translate transaction coordinator errors to known producer response errors
-            val customException =
-              error match {
-                case Errors.INVALID_TXN_STATE => Some(error.exception("Partition was not added to the transaction"))
-                case Errors.CONCURRENT_TRANSACTIONS |
-                     Errors.COORDINATOR_LOAD_IN_PROGRESS |
-                     Errors.COORDINATOR_NOT_AVAILABLE |
-                     Errors.NOT_COORDINATOR => Some(new NotEnoughReplicasException(
-                  s"Unable to verify the partition has been added to the transaction. Underlying error: ${error.toString}"))
-                case _ => None
-              }
-            topicPartition -> LogAppendResult(
-              LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
-              Some(customException.getOrElse(error.exception)),
-              hasCustomErrorMessage = customException.isDefined
-            )
-        }
-
-        val allResults = localProduceResults ++ errorResults
-        val produceStatus = allResults.map { case (topicPartition, result) =>
-          topicPartition -> ProducePartitionStatus(
-            result.info.lastOffset + 1, // required offset
-            new PartitionResponse(
-              result.error,
-              result.info.firstOffset,
-              result.info.lastOffset,
-              result.info.logAppendTime,
-              result.info.logStartOffset,
-              result.info.recordErrors,
-              result.errorMessage
-            )
-          ) // response status
-        }
-
-        actionQueue.add {
-          () => allResults.foreach { case (topicPartition, result) =>
-            val requestKey = TopicPartitionOperationKey(topicPartition)
-            result.info.leaderHwChange match {
-              case LeaderHwChange.INCREASED =>
-                // some delayed operations may be unblocked after HW changed
-                delayedProducePurgatory.checkAndComplete(requestKey)
-                delayedFetchPurgatory.checkAndComplete(requestKey)
-                delayedDeleteRecordsPurgatory.checkAndComplete(requestKey)
-              case LeaderHwChange.SAME =>
-                // probably unblock some follower fetch requests since log end offset has been updated
-                delayedFetchPurgatory.checkAndComplete(requestKey)
-              case LeaderHwChange.NONE =>
-                // nothing
-            }
-          }
-        }
-
-        recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats })
-
-        if (delayedProduceRequestRequired(requiredAcks, allEntries, allResults)) {
-          // create delayed produce operation
-          val produceMetadata = ProduceMetadata(requiredAcks, produceStatus)
-          val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock)
-
-          // create a list of (topic, partition) pairs to use as keys for this delayed produce operation
-          val producerRequestKeys = allEntries.keys.map(TopicPartitionOperationKey(_)).toSeq
-
-          // try to complete the request immediately, otherwise put it into the purgatory
-          // this is because while the delayed produce operation is being created, new
-          // requests may arrive and hence make this operation completable.
-          delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
-        } else {
-          // we can respond immediately
-          val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus }
-          responseCallback(produceResponseStatus)
-        }
-      }
-
       if (notYetVerifiedEntriesPerPartition.isEmpty || addPartitionsToTxnManager.isEmpty) {
-        appendEntries(verifiedEntriesPerPartition)(Map.empty)
+        appendEntries(verifiedEntriesPerPartition, internalTopicsAllowed, origin, requiredAcks, verificationGuards.toMap,
+          errorsPerPartition, recordConversionStatsCallback, timeout, responseCallback, delayedProduceLock)(requestLocal, Map.empty)
       } else {
         // For unverified entries, send a request to verify. When verified, the append process will proceed via the callback.
         // We verify above that all partitions use the same producer ID.
@@ -890,7 +802,20 @@ class ReplicaManager(val config: KafkaConfig,
         producerId = batchInfo.producerId,
         producerEpoch = batchInfo.producerEpoch,
         topicPartitions = notYetVerifiedEntriesPerPartition.keySet.toSeq,
-        callback = KafkaRequestHandler.wrap(appendEntries(entriesPerPartition)(_))
+        callback = KafkaRequestHandler.wrapAsyncCallback(
+          appendEntries(
+            entriesPerPartition,
+            internalTopicsAllowed,
+            origin,
+            requiredAcks,
+            verificationGuards.toMap,
+            errorsPerPartition,
+            recordConversionStatsCallback,
+            timeout,
+            responseCallback,
+            delayedProduceLock
+          ),
+          requestLocal)
       ))
     }
   } else {
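
Note the shape of the call above: appendEntries is now curried, so supplying only its first parameter list yields a (RequestLocal, Map[TopicPartition, Errors]) => Unit, which is exactly the callback shape wrapAsyncCallback expects. A simplified, self-contained sketch of the pattern (toy types, not the real signature):

object CurrySketch {
  type RequestLocal = String // toy stand-ins for the real types
  type Errors = String

  def appendEntries(allEntries: Map[String, Int], requiredAcks: Short)
                   (requestLocal: RequestLocal, unverified: Map[String, Errors]): Unit =
    println(s"append ${allEntries.size} entries (acks=$requiredAcks) with $requestLocal, " +
      s"${unverified.size} unverified")

  def main(args: Array[String]): Unit = {
    // Partially applying the first list defers the second until the callback runs.
    val pending: (RequestLocal, Map[String, Errors]) => Unit =
      appendEntries(Map("tp-0" -> 1), requiredAcks = -1)

    // Later, the executing request thread supplies its own RequestLocal.
    pending("handler-1-request-local", Map.empty)
  }
}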
@@ -908,6 +833,110 @@ class ReplicaManager(val config: KafkaConfig,
     }
   }
 
+  /*
+   * Note: This method can be used as a callback in a different request thread. Ensure that correct RequestLocal
+   * is passed when executing this method. Accessing non-thread-safe data structures should be avoided if possible.
+   */
+  private def appendEntries(allEntries: Map[TopicPartition, MemoryRecords],
+                            internalTopicsAllowed: Boolean,
+                            origin: AppendOrigin,
+                            requiredAcks: Short,
+                            verificationGuards: Map[TopicPartition, VerificationGuard],
+                            errorsPerPartition: Map[TopicPartition, Errors],
+                            recordConversionStatsCallback: Map[TopicPartition, RecordConversionStats] => Unit,
+                            timeout: Long,
+                            responseCallback: Map[TopicPartition, PartitionResponse] => Unit,
+                            delayedProduceLock: Option[Lock])
+                           (requestLocal: RequestLocal, unverifiedEntries: Map[TopicPartition, Errors]): Unit = {
+    val sTime = time.milliseconds
+    val verifiedEntries =
+      if (unverifiedEntries.isEmpty)
+        allEntries
+      else
+        allEntries.filter { case (tp, _) =>
+          !unverifiedEntries.contains(tp)
+        }
+
+    val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
+      origin, verifiedEntries, requiredAcks, requestLocal, verificationGuards.toMap)
+    debug("Produce to local log in %d ms".format(time.milliseconds - sTime))
+
+    val errorResults = (unverifiedEntries ++ errorsPerPartition).map {
+      case (topicPartition, error) =>
+        // translate transaction coordinator errors to known producer response errors
+        val customException =
+          error match {
+            case Errors.INVALID_TXN_STATE => Some(error.exception("Partition was not added to the transaction"))
+            case Errors.CONCURRENT_TRANSACTIONS |
+                 Errors.COORDINATOR_LOAD_IN_PROGRESS |
+                 Errors.COORDINATOR_NOT_AVAILABLE |
+                 Errors.NOT_COORDINATOR => Some(new NotEnoughReplicasException(
+              s"Unable to verify the partition has been added to the transaction. Underlying error: ${error.toString}"))
+            case _ => None
+          }
+        topicPartition -> LogAppendResult(
+          LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
+          Some(customException.getOrElse(error.exception)),
+          hasCustomErrorMessage = customException.isDefined
+        )
+    }
+
+    val allResults = localProduceResults ++ errorResults
+    val produceStatus = allResults.map { case (topicPartition, result) =>
+      topicPartition -> ProducePartitionStatus(
+        result.info.lastOffset + 1, // required offset
+        new PartitionResponse(
+          result.error,
+          result.info.firstOffset,
+          result.info.lastOffset,
+          result.info.logAppendTime,
+          result.info.logStartOffset,
+          result.info.recordErrors,
+          result.errorMessage
+        )
+      ) // response status
+    }
+
+    actionQueue.add {
+      () =>
+        allResults.foreach { case (topicPartition, result) =>
+          val requestKey = TopicPartitionOperationKey(topicPartition)
+          result.info.leaderHwChange match {
+            case LeaderHwChange.INCREASED =>
+              // some delayed operations may be unblocked after HW changed
+              delayedProducePurgatory.checkAndComplete(requestKey)
+              delayedFetchPurgatory.checkAndComplete(requestKey)
+              delayedDeleteRecordsPurgatory.checkAndComplete(requestKey)
+            case LeaderHwChange.SAME =>
+              // probably unblock some follower fetch requests since log end offset has been updated
+              delayedFetchPurgatory.checkAndComplete(requestKey)
+            case LeaderHwChange.NONE =>
+              // nothing
+          }
+        }
+    }
+
+    recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats })
+
+    if (delayedProduceRequestRequired(requiredAcks, allEntries, allResults)) {
+      // create delayed produce operation
+      val produceMetadata = ProduceMetadata(requiredAcks, produceStatus)
+      val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock)
+
+      // create a list of (topic, partition) pairs to use as keys for this delayed produce operation
+      val producerRequestKeys = allEntries.keys.map(TopicPartitionOperationKey(_)).toSeq
+
+      // try to complete the request immediately, otherwise put it into the purgatory
+      // this is because while the delayed produce operation is being created, new
+      // requests may arrive and hence make this operation completable.
+      delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
+    } else {
+      // we can respond immediately
+      val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus }
+      responseCallback(produceResponseStatus)
+    }
+  }
+
   private def partitionEntriesForVerification(verificationGuards: mutable.Map[TopicPartition, VerificationGuard],
                                               entriesPerPartition: Map[TopicPartition, MemoryRecords],
                                               verifiedEntries: mutable.Map[TopicPartition, MemoryRecords],
core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala

@@ -24,7 +24,7 @@ import org.apache.kafka.common.network.{ClientInformation, ListenerName}
 import org.apache.kafka.common.protocol.ApiKeys
 import org.apache.kafka.common.requests.{RequestContext, RequestHeader}
 import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
-import org.apache.kafka.common.utils.{MockTime, Time}
+import org.apache.kafka.common.utils.{BufferSupplier, MockTime, Time}
 import org.apache.kafka.server.log.remote.storage.{RemoteLogManagerConfig, RemoteStorageMetrics}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue}
 import org.junit.jupiter.api.Test
@@ -32,7 +32,7 @@ import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
 import org.mockito.ArgumentMatchers
 import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.{mock, when}
+import org.mockito.Mockito.{mock, times, verify, when}
 
 import java.net.InetAddress
 import java.nio.ByteBuffer
@@ -57,10 +57,12 @@ class KafkaRequestHandlerTest {
     when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
       time.sleep(2)
       // Prepare the callback.
-      val callback = KafkaRequestHandler.wrap((ms: Int) => {
-        time.sleep(ms)
-        handler.stop()
-      })
+      val callback = KafkaRequestHandler.wrapAsyncCallback(
+        (reqLocal: RequestLocal, ms: Int) => {
+          time.sleep(ms)
+          handler.stop()
+        },
+        RequestLocal.NoCaching)
       // Execute the callback asynchronously.
       CompletableFuture.runAsync(() => callback(1))
       request.apiLocalCompleteTimeNanos = time.nanoseconds
@@ -94,9 +96,11 @@ class KafkaRequestHandlerTest {
     when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
       handledCount = handledCount + 1
       // Prepare the callback.
-      val callback = KafkaRequestHandler.wrap((ms: Int) => {
-        handler.stop()
-      })
+      val callback = KafkaRequestHandler.wrapAsyncCallback(
+        (reqLocal: RequestLocal, ms: Int) => {
+          handler.stop()
+        },
+        RequestLocal.NoCaching)
       // Execute the callback asynchronously.
       CompletableFuture.runAsync(() => callback(1))
     }
@@ -111,6 +115,75 @@ class KafkaRequestHandlerTest {
     assertEquals(1, tryCompleteActionCount)
   }
 
+  @Test
+  def testHandlingCallbackOnNewThread(): Unit = {
+    val time = new MockTime()
+    val metrics = mock(classOf[RequestChannel.Metrics])
+    val apiHandler = mock(classOf[ApiRequestHandler])
+    val requestChannel = new RequestChannel(10, "", time, metrics)
+    val handler = new KafkaRequestHandler(0, 0, mock(classOf[Meter]), new AtomicInteger(1), requestChannel, apiHandler, time)
+
+    val originalRequestLocal = mock(classOf[RequestLocal])
+
+    var handledCount = 0
+
+    val request = makeRequest(time, metrics)
+    requestChannel.sendRequest(request)
+
+    when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
+      // Prepare the callback.
+      val callback = KafkaRequestHandler.wrapAsyncCallback(
+        (reqLocal: RequestLocal, ms: Int) => {
+          reqLocal.bufferSupplier.close()
+          handledCount = handledCount + 1
+          handler.stop()
+        },
+        originalRequestLocal)
+      // Execute the callback asynchronously.
+      CompletableFuture.runAsync(() => callback(1))
+    }
+
+    handler.run()
+    // Verify that we don't use the request local that we passed in.
+    verify(originalRequestLocal, times(0)).bufferSupplier
+    assertEquals(1, handledCount)
+  }
+
+  @Test
+  def testCallbackOnSameThread(): Unit = {
+    val time = new MockTime()
+    val metrics = mock(classOf[RequestChannel.Metrics])
+    val apiHandler = mock(classOf[ApiRequestHandler])
+    val requestChannel = new RequestChannel(10, "", time, metrics)
+    val handler = new KafkaRequestHandler(0, 0, mock(classOf[Meter]), new AtomicInteger(1), requestChannel, apiHandler, time)
+
+    val originalRequestLocal = mock(classOf[RequestLocal])
+    when(originalRequestLocal.bufferSupplier).thenReturn(BufferSupplier.create())
+
+    var handledCount = 0
+
+    val request = makeRequest(time, metrics)
+    requestChannel.sendRequest(request)
+
+    when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
+      // Prepare the callback.
+      val callback = KafkaRequestHandler.wrapAsyncCallback(
+        (reqLocal: RequestLocal, ms: Int) => {
+          reqLocal.bufferSupplier.close()
+          handledCount = handledCount + 1
+          handler.stop()
+        },
+        originalRequestLocal)
+      // Execute the callback before the request returns.
+      callback(1)
+    }
+
+    handler.run()
+    // Verify that we do use the request local that we passed in.
+    verify(originalRequestLocal, times(1)).bufferSupplier
+    assertEquals(1, handledCount)
+  }
+
   @ParameterizedTest
   @ValueSource(booleans = Array(true, false))
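
The two tests above pivot on a simple Mockito idiom: if the wrapped callback was rescheduled onto a handler thread, the RequestLocal passed into wrapAsyncCallback must never be touched, so verifying zero interactions with its bufferSupplier proves the handler's own RequestLocal was used; on the synchronous path, exactly one interaction is expected. A standalone sketch of that idiom (toy trait, not Kafka's types):

import org.mockito.Mockito.{mock, times, verify}

object VerifySketch {
  trait PerThreadState { def bufferSupplier: AnyRef }

  def main(args: Array[String]): Unit = {
    val passedIn = mock(classOf[PerThreadState])
    // ... exercise code that should use the executing thread's state instead ...
    // On the asynchronous path the caller's state must remain untouched:
    verify(passedIn, times(0)).bufferSupplier
  }
}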