KAFKA-14869: Ignore unknown record types for coordinators (KIP-915, Part-1) (#13598)

This patch implements the first part of KIP-915. It updates the group coordinator and the transaction coordinator to ignore unknown record types while loading their respective state from the partitions. This allows downgrades from future versions that will include new record types.

Reviewers: Alexandre Dupriez <alexandre.dupriez@gmail.com>, David Jacot <djacot@confluent.io>
Jeff Kim 2023-04-21 12:28:20 -04:00 committed by GitHub
parent d62859274a
commit f5a5bc8418
6 changed files with 184 additions and 46 deletions
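
The same pattern repeats across every file below: key parsing returns an UnknownKey marker instead of throwing when the record's version prefix is newer than the broker understands, and each caller logs a warning and skips the record. A minimal, self-contained sketch of that pattern (illustrative names, not the actual Kafka classes):

import java.nio.ByteBuffer

// Stand-ins for the generated schema classes and their version bounds.
sealed trait BaseKey { def version: Short }
case class OffsetKey(version: Short, group: String) extends BaseKey
case class UnknownKey(version: Short) extends BaseKey

val highestSupportedVersion: Short = 1 // stand-in for HIGHEST_SUPPORTED_VERSION

def readMessageKey(buffer: ByteBuffer): BaseKey = {
  val version = buffer.getShort
  if (version >= 0 && version <= highestSupportedVersion)
    OffsetKey(version, "parsed-group") // the real code decodes the remaining bytes
  else
    UnknownKey(version) // likely written by a newer broker; the caller warns and skips
}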

GroupMetadataManager.scala

@ -648,7 +648,6 @@ class GroupMetadataManager(brokerId: Int,
if (batchBaseOffset.isEmpty)
batchBaseOffset = Some(record.offset)
GroupMetadataManager.readMessageKey(record.key) match {
case offsetKey: OffsetKey =>
if (isTxnOffsetCommit && !pendingOffsets.contains(batch.producerId))
pendingOffsets.put(batch.producerId, mutable.Map[GroupTopicPartition, CommitRecordMetadataAndOffset]())
@ -680,8 +679,10 @@ class GroupMetadataManager(brokerId: Int,
removedGroups.add(groupId)
}
-          case unknownKey =>
-            throw new IllegalStateException(s"Unexpected message key $unknownKey while loading offsets and group metadata")
+          case unknownKey: UnknownKey =>
+            warn(s"Unknown message key with version ${unknownKey.version}" +
+              s" while loading offsets and group metadata from $topicPartition. Ignoring it. " +
+              "It could be a left over from an aborted upgrade.")
}
}
}
@ -1150,7 +1151,9 @@ object GroupMetadataManager {
// version 2 refers to group metadata
val key = new GroupMetadataKeyData(new ByteBufferAccessor(buffer), version)
GroupMetadataKey(version, key.group)
-    } else throw new IllegalStateException(s"Unknown group metadata message version: $version")
+    } else {
+      UnknownKey(version)
+    }
}
/**
@ -1270,7 +1273,7 @@ object GroupMetadataManager {
GroupMetadataManager.readMessageKey(record.key) match {
case offsetKey: OffsetKey => parseOffsets(offsetKey, record.value)
case groupMetadataKey: GroupMetadataKey => parseGroupMetadata(groupMetadataKey, record.value)
-      case _ => throw new KafkaException("Failed to decode message using offset topic decoder (message had an invalid key)")
+      case unknownKey: UnknownKey => (Some(s"unknown::version=${unknownKey.version}"), None)
}
}
}
@ -1348,18 +1351,20 @@ case class GroupTopicPartition(group: String, topicPartition: TopicPartition) {
"[%s,%s,%d]".format(group, topicPartition.topic, topicPartition.partition)
}
-trait BaseKey{
+sealed trait BaseKey{
def version: Short
def key: Any
}
case class OffsetKey(version: Short, key: GroupTopicPartition) extends BaseKey {
override def toString: String = key.toString
}
case class GroupMetadataKey(version: Short, key: String) extends BaseKey {
override def toString: String = key
}
+case class UnknownKey(version: Short) extends BaseKey {
+  override def key: String = null
+  override def toString: String = key
+}
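
Sealing BaseKey is what makes the new UnknownKey variant safe to introduce: the compiler can now flag any match over BaseKey that forgets a case, so every consumer of these keys has to decide explicitly how to handle unknown ones. A sketch under the same illustrative-names assumption:

sealed trait BaseKey { def version: Short; def key: Any }
case class OffsetKey(version: Short, key: String) extends BaseKey
case class GroupMetadataKey(version: Short, key: String) extends BaseKey
case class UnknownKey(version: Short) extends BaseKey { override def key: Any = null }

// Exhaustive by construction: removing the UnknownKey arm would draw a
// compiler warning because the trait is sealed.
def render(k: BaseKey): String = k match {
  case OffsetKey(v, key)        => s"offset(v$v)=$key"
  case GroupMetadataKey(v, key) => s"group(v$v)=$key"
  case UnknownKey(v)            => s"unknown::version=$v" // skipped, not fatal
}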

TransactionLog.scala

@ -19,7 +19,6 @@ package kafka.coordinator.transaction
import java.io.PrintStream
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import kafka.internals.generated.{TransactionLogKey, TransactionLogValue}
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil}
@ -98,7 +97,7 @@ object TransactionLog {
*
* @return the key
*/
-  def readTxnRecordKey(buffer: ByteBuffer): TxnKey = {
+  def readTxnRecordKey(buffer: ByteBuffer): BaseKey = {
val version = buffer.getShort
if (version >= TransactionLogKey.LOWEST_SUPPORTED_VERSION && version <= TransactionLogKey.HIGHEST_SUPPORTED_VERSION) {
val value = new TransactionLogKey(new ByteBufferAccessor(buffer), version)
@ -106,7 +105,9 @@ object TransactionLog {
version = version,
transactionalId = value.transactionalId
)
-    } else throw new IllegalStateException(s"Unknown version $version from the transaction log message")
+    } else {
+      UnknownKey(version)
+    }
}
/**
@ -148,17 +149,21 @@ object TransactionLog {
// Formatter for use with tools to read transaction log messages
class TransactionLogMessageFormatter extends MessageFormatter {
def writeTo(consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]], output: PrintStream): Unit = {
-      Option(consumerRecord.key).map(key => readTxnRecordKey(ByteBuffer.wrap(key))).foreach { txnKey =>
-        val transactionalId = txnKey.transactionalId
-        val value = consumerRecord.value
-        val producerIdMetadata = if (value == null)
-          None
-        else
-          readTxnRecordValue(transactionalId, ByteBuffer.wrap(value))
-        output.write(transactionalId.getBytes(StandardCharsets.UTF_8))
-        output.write("::".getBytes(StandardCharsets.UTF_8))
-        output.write(producerIdMetadata.getOrElse("NULL").toString.getBytes(StandardCharsets.UTF_8))
-        output.write("\n".getBytes(StandardCharsets.UTF_8))
+      Option(consumerRecord.key).map(key => readTxnRecordKey(ByteBuffer.wrap(key))).foreach {
+        case txnKey: TxnKey =>
+          val transactionalId = txnKey.transactionalId
+          val value = consumerRecord.value
+          val producerIdMetadata = if (value == null)
+            None
+          else
+            readTxnRecordValue(transactionalId, ByteBuffer.wrap(value))
+          output.write(transactionalId.getBytes(StandardCharsets.UTF_8))
+          output.write("::".getBytes(StandardCharsets.UTF_8))
+          output.write(producerIdMetadata.getOrElse("NULL").toString.getBytes(StandardCharsets.UTF_8))
+          output.write("\n".getBytes(StandardCharsets.UTF_8))
+        case unknownKey: UnknownKey =>
+          output.write(s"unknown::version=${unknownKey.version}\n".getBytes(StandardCharsets.UTF_8))
}
}
}
@ -167,25 +172,41 @@ object TransactionLog {
* Exposed for printing records using [[kafka.tools.DumpLogSegments]]
*/
def formatRecordKeyAndValue(record: Record): (Option[String], Option[String]) = {
-    val txnKey = TransactionLog.readTxnRecordKey(record.key)
-    val keyString = s"transaction_metadata::transactionalId=${txnKey.transactionalId}"
-    val valueString = TransactionLog.readTxnRecordValue(txnKey.transactionalId, record.value) match {
-      case None => "<DELETE>"
-      case Some(txnMetadata) => s"producerId:${txnMetadata.producerId}," +
-        s"producerEpoch:${txnMetadata.producerEpoch}," +
-        s"state=${txnMetadata.state}," +
-        s"partitions=${txnMetadata.topicPartitions.mkString("[", ",", "]")}," +
-        s"txnLastUpdateTimestamp=${txnMetadata.txnLastUpdateTimestamp}," +
-        s"txnTimeoutMs=${txnMetadata.txnTimeoutMs}"
-    }
-    (Some(keyString), Some(valueString))
+    TransactionLog.readTxnRecordKey(record.key) match {
+      case txnKey: TxnKey =>
+        val keyString = s"transaction_metadata::transactionalId=${txnKey.transactionalId}"
+        val valueString = TransactionLog.readTxnRecordValue(txnKey.transactionalId, record.value) match {
+          case None => "<DELETE>"
+          case Some(txnMetadata) => s"producerId:${txnMetadata.producerId}," +
+            s"producerEpoch:${txnMetadata.producerEpoch}," +
+            s"state=${txnMetadata.state}," +
+            s"partitions=${txnMetadata.topicPartitions.mkString("[", ",", "]")}," +
+            s"txnLastUpdateTimestamp=${txnMetadata.txnLastUpdateTimestamp}," +
+            s"txnTimeoutMs=${txnMetadata.txnTimeoutMs}"
+        }
+        (Some(keyString), Some(valueString))
+      case unknownKey: UnknownKey =>
+        (Some(s"unknown::version=${unknownKey.version}"), None)
+    }
}
}
-case class TxnKey(version: Short, transactionalId: String) {
+sealed trait BaseKey{
+  def version: Short
+  def transactionalId: String
+}
+
+case class TxnKey(version: Short, transactionalId: String) extends BaseKey {
override def toString: String = transactionalId
}
+case class UnknownKey(version: Short) extends BaseKey {
+  override def transactionalId: String = null
+  override def toString: String = transactionalId
+}
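
With this change, tooling such as DumpLogSegments renders an unknown key as a placeholder string instead of aborting the dump. A sketch of the output shape (illustrative types; the format strings mirror the diff above):

sealed trait BaseKey { def version: Short }
case class TxnKey(version: Short, transactionalId: String) extends BaseKey
case class UnknownKey(version: Short) extends BaseKey

def formatKey(k: BaseKey): String = k match {
  case TxnKey(_, id) => s"transaction_metadata::transactionalId=$id"
  case UnknownKey(v) => s"unknown::version=$v"
}

println(formatKey(TxnKey(0, "txn-1")))         // transaction_metadata::transactionalId=txn-1
println(formatKey(UnknownKey(Short.MaxValue))) // unknown::version=32767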

TransactionStateManager.scala

@ -466,16 +466,23 @@ class TransactionStateManager(brokerId: Int,
memRecords.batches.forEach { batch =>
for (record <- batch.asScala) {
require(record.hasKey, "Transaction state log's key should not be null")
-      val txnKey = TransactionLog.readTxnRecordKey(record.key)
-      // load transaction metadata along with transaction state
-      val transactionalId = txnKey.transactionalId
-      TransactionLog.readTxnRecordValue(transactionalId, record.value) match {
-        case None =>
-          loadedTransactions.remove(transactionalId)
-        case Some(txnMetadata) =>
-          loadedTransactions.put(transactionalId, txnMetadata)
-      }
-      currOffset = batch.nextOffset
+      TransactionLog.readTxnRecordKey(record.key) match {
+        case txnKey: TxnKey =>
+          // load transaction metadata along with transaction state
+          val transactionalId = txnKey.transactionalId
+          TransactionLog.readTxnRecordValue(transactionalId, record.value) match {
+            case None =>
+              loadedTransactions.remove(transactionalId)
+            case Some(txnMetadata) =>
+              loadedTransactions.put(transactionalId, txnMetadata)
+          }
+          currOffset = batch.nextOffset
+        case unknownKey: UnknownKey =>
+          warn(s"Unknown message key with version ${unknownKey.version}" +
+            s" while loading transaction state from $topicPartition. Ignoring it. " +
+            "It could be a left over from an aborted upgrade.")
+      }
}
}
}
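
The loader is tolerant only at the key-parsing step: a record with a known key is applied to the in-memory state exactly as before, while one with an unknown key is logged and skipped. A self-contained sketch of that warn-and-skip loop (illustrative types; the real loop walks record batches and tracks offsets):

sealed trait BaseKey
case class TxnKey(version: Short, transactionalId: String) extends BaseKey
case class UnknownKey(version: Short) extends BaseKey

def load(keys: Seq[BaseKey]): Map[String, String] = {
  val loaded = scala.collection.mutable.Map[String, String]()
  keys.foreach {
    case TxnKey(_, id) => loaded.put(id, s"metadata-for-$id") // the real code decodes record.value
    case UnknownKey(v) => println(s"Unknown message key with version $v. Ignoring it.")
  }
  loaded.toMap
}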

GroupMetadataManagerTest.scala

@ -37,7 +37,7 @@ import org.apache.kafka.clients.consumer.internals.ConsumerProtocol
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.internals.Topic
import org.apache.kafka.common.metrics.{JmxReporter, KafkaMetricsContext, Metrics => kMetrics}
-import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.protocol.{Errors, MessageUtil}
import org.apache.kafka.common.record._
import org.apache.kafka.common.requests.OffsetFetchResponse
import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
@ -637,6 +637,7 @@ class GroupMetadataManagerTest {
val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets)
val memberId = "98098230493"
val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId)
val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE,
(offsetCommitRecords ++ Seq(groupMetadataRecord)).toArray: _*)
@ -2634,4 +2635,60 @@ class GroupMetadataManagerTest {
assertTrue(partitionLoadTime("partition-load-time-max") >= diff)
assertTrue(partitionLoadTime("partition-load-time-avg") >= diff)
}
@Test
def testReadMessageKeyCanReadUnknownMessage(): Unit = {
  val record = new kafka.internals.generated.GroupMetadataKey()
  val unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, record)
  val key = GroupMetadataManager.readMessageKey(ByteBuffer.wrap(unknownRecord))
  assertEquals(UnknownKey(Short.MaxValue), key)
}

@Test
def testLoadGroupsAndOffsetsWillIgnoreUnknownMessage(): Unit = {
  val generation = 935
  val protocolType = "consumer"
  val protocol = "range"
  val startOffset = 15L
  val committedOffsets = Map(
    new TopicPartition("foo", 0) -> 23L,
    new TopicPartition("foo", 1) -> 455L,
    new TopicPartition("bar", 0) -> 8992L
  )

  val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets)
  val memberId = "98098230493"
  val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId)

  // Should ignore unknown record
  val unknownKey = new kafka.internals.generated.GroupMetadataKey()
  val lowestUnsupportedVersion = (kafka.internals.generated.GroupMetadataKey
    .HIGHEST_SUPPORTED_VERSION + 1).toShort
  val unknownMessage1 = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, unknownKey)
  val unknownMessage2 = MessageUtil.toVersionPrefixedBytes(lowestUnsupportedVersion, unknownKey)
  val unknownRecord1 = new SimpleRecord(unknownMessage1, unknownMessage1)
  val unknownRecord2 = new SimpleRecord(unknownMessage2, unknownMessage2)

  val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE,
    (offsetCommitRecords ++ Seq(unknownRecord1, unknownRecord2) ++ Seq(groupMetadataRecord)).toArray: _*)
  expectGroupMetadataLoad(groupTopicPartition, startOffset, records)
  groupMetadataManager.loadGroupsAndOffsets(groupTopicPartition, 1, _ => (), 0L)

  val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache"))
  assertEquals(groupId, group.groupId)
  assertEquals(Stable, group.currentState)
  assertEquals(memberId, group.leaderOrNull)
  assertEquals(generation, group.generationId)
  assertEquals(Some(protocolType), group.protocolType)
  assertEquals(protocol, group.protocolName.orNull)
  assertEquals(Set(memberId), group.allMembers)
  assertEquals(committedOffsets.size, group.allOffsets.size)
  committedOffsets.foreach { case (topicPartition, offset) =>
    assertEquals(Some(offset), group.offset(topicPartition).map(_.offset))
    assertTrue(group.offset(topicPartition).map(_.expireTimestamp).contains(None))
  }
}
}

TransactionLogTest.scala

@ -17,12 +17,15 @@
package kafka.coordinator.transaction
import kafka.internals.generated.TransactionLogKey
import kafka.utils.TestUtils
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.MessageUtil
import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord}
import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows}
import org.junit.jupiter.api.Test
import java.nio.ByteBuffer
import scala.jdk.CollectionConverters._
class TransactionLogTest {
@ -135,4 +138,11 @@ class TransactionLogTest {
assertEquals(Some("<DELETE>"), valueStringOpt)
}
@Test
def testReadTxnRecordKeyCanReadUnknownMessage(): Unit = {
  val record = new TransactionLogKey()
  val unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, record)
  val key = TransactionLog.readTxnRecordKey(ByteBuffer.wrap(unknownRecord))
  assertEquals(UnknownKey(Short.MaxValue), key)
}
}
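
Both test files forge a record "from the future" the same way: MessageUtil.toVersionPrefixedBytes (the real Kafka helper used in the diffs above) prepends the given version to the serialized key, and Short.MaxValue sits far above any HIGHEST_SUPPORTED_VERSION. A sketch, assuming Kafka's test classpath:

import java.nio.ByteBuffer
import kafka.internals.generated.TransactionLogKey
import org.apache.kafka.common.protocol.MessageUtil

// Version 32767 is unknown to this broker, so readTxnRecordKey should return
// UnknownKey(32767) rather than throwing.
val unknownBytes = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, new TransactionLogKey())
assert(ByteBuffer.wrap(unknownBytes).getShort == Short.MaxValue)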

TransactionStateManagerTest.scala

@ -16,6 +16,8 @@
*/
package kafka.coordinator.transaction
import kafka.internals.generated.TransactionLogKey
import java.lang.management.ManagementFactory
import java.nio.ByteBuffer
import java.util.concurrent.CountDownLatch
@ -29,7 +31,7 @@ import kafka.zk.KafkaZkClient
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME
import org.apache.kafka.common.metrics.{JmxReporter, KafkaMetricsContext, Metrics}
-import org.apache.kafka.common.protocol.Errors
+import org.apache.kafka.common.protocol.{Errors, MessageUtil}
import org.apache.kafka.common.record._
import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
import org.apache.kafka.common.requests.TransactionResult
@ -1133,4 +1135,40 @@ class TransactionStateManagerTest {
assertTrue(partitionLoadTime("partition-load-time-max") >= 0)
assertTrue(partitionLoadTime( "partition-load-time-avg") >= 0)
}
@Test
def testIgnoreUnknownRecordType(): Unit = {
  txnMetadata1.state = PrepareCommit
  txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0),
    new TopicPartition("topic1", 1)))
  txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit()))

  val startOffset = 0L
  val unknownKey = new TransactionLogKey()
  val unknownMessage = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, unknownKey)
  val unknownRecord = new SimpleRecord(unknownMessage, unknownMessage)
  val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE,
    (Seq(unknownRecord) ++ txnRecords).toArray: _*)

  prepareTxnLog(topicPartition, 0, records)
  transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch = 1, (_, _, _, _) => ())

  assertEquals(0, transactionManager.loadingPartitions.size)
  assertTrue(transactionManager.transactionMetadataCache.contains(partitionId))
  val txnMetadataPool = transactionManager.transactionMetadataCache(partitionId).metadataPerTransactionalId
  assertFalse(txnMetadataPool.isEmpty)
  assertTrue(txnMetadataPool.contains(transactionalId1))
  val txnMetadata = txnMetadataPool.get(transactionalId1)
  assertEquals(txnMetadata1.transactionalId, txnMetadata.transactionalId)
  assertEquals(txnMetadata1.producerId, txnMetadata.producerId)
  assertEquals(txnMetadata1.lastProducerId, txnMetadata.lastProducerId)
  assertEquals(txnMetadata1.producerEpoch, txnMetadata.producerEpoch)
  assertEquals(txnMetadata1.lastProducerEpoch, txnMetadata.lastProducerEpoch)
  assertEquals(txnMetadata1.txnTimeoutMs, txnMetadata.txnTimeoutMs)
  assertEquals(txnMetadata1.state, txnMetadata.state)
  assertEquals(txnMetadata1.topicPartitions, txnMetadata.topicPartitions)
  assertEquals(1, transactionManager.transactionMetadataCache(partitionId).coordinatorEpoch)
}
}