diff --git a/checkstyle/import-control-transaction-coordinator.xml b/checkstyle/import-control-transaction-coordinator.xml
index a48100a9acc..810c127c95c 100644
--- a/checkstyle/import-control-transaction-coordinator.xml
+++ b/checkstyle/import-control-transaction-coordinator.xml
@@ -40,6 +40,7 @@
+
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
index a3e9eacb66f..f024e88aa8e 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala
@@ -52,7 +52,7 @@ object TransactionLog {
*
* @return key bytes
*/
- private[transaction] def keyToBytes(transactionalId: String): Array[Byte] = {
+ def keyToBytes(transactionalId: String): Array[Byte] = {
MessageUtil.toCoordinatorTypePrefixedBytes(new TransactionLogKey().setTransactionalId(transactionalId))
}
@@ -61,7 +61,7 @@ object TransactionLog {
*
* @return value payload bytes
*/
- private[transaction] def valueToBytes(txnMetadata: TxnTransitMetadata,
+ def valueToBytes(txnMetadata: TxnTransitMetadata,
transactionVersionLevel: TransactionVersion): Array[Byte] = {
if (txnMetadata.txnState == TransactionState.EMPTY && !txnMetadata.topicPartitions.isEmpty)
throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata")
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
deleted file mode 100644
index 46116be6d7c..00000000000
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala
+++ /dev/null
@@ -1,246 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package kafka.coordinator.transaction
-
-import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.compress.Compression
-import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil}
-import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection
-import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type}
-import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, SimpleRecord}
-import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState, TxnTransitMetadata}
-import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue}
-import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2}
-import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue, fail}
-import org.junit.jupiter.api.Test
-
-import java.nio.ByteBuffer
-import java.util
-import scala.collection.Seq
-import scala.jdk.CollectionConverters._
-
-class TransactionLogTest {
-
- val producerEpoch: Short = 0
- val transactionTimeoutMs: Int = 1000
-
- val topicPartitions = util.Set.of(new TopicPartition("topic1", 0),
- new TopicPartition("topic1", 1),
- new TopicPartition("topic2", 0),
- new TopicPartition("topic2", 1),
- new TopicPartition("topic2", 2))
-
- @Test
- def shouldThrowExceptionWriteInvalidTxn(): Unit = {
- val transactionalId = "transactionalId"
- val producerId = 23423L
-
- val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
- RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, util.Set.of, 0, 0, TV_0)
- txnMetadata.addPartitions(topicPartitions)
-
- assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2))
- }
-
- @Test
- def shouldReadWriteMessages(): Unit = {
- val pidMappings = Map[String, Long]("zero" -> 0L,
- "one" -> 1L,
- "two" -> 2L,
- "three" -> 3L,
- "four" -> 4L,
- "five" -> 5L)
-
- val transactionStates = Map[Long, TransactionState](0L -> TransactionState.EMPTY,
- 1L -> TransactionState.ONGOING,
- 2L -> TransactionState.PREPARE_COMMIT,
- 3L -> TransactionState.COMPLETE_COMMIT,
- 4L -> TransactionState.PREPARE_ABORT,
- 5L -> TransactionState.COMPLETE_ABORT)
-
- // generate transaction log messages
- val txnRecords = pidMappings.map { case (transactionalId, producerId) =>
- val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch,
- RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), util.Set.of, 0, 0, TV_0)
-
- if (!txnMetadata.state.equals(TransactionState.EMPTY))
- txnMetadata.addPartitions(topicPartitions)
-
- val keyBytes = TransactionLog.keyToBytes(transactionalId)
- val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)
-
- new SimpleRecord(keyBytes, valueBytes)
- }.toSeq
-
- val records = MemoryRecords.withRecords(0, Compression.NONE, txnRecords: _*)
-
- var count = 0
- for (record <- records.records.asScala) {
- TransactionLog.readTxnRecordKey(record.key) match {
- case Left(version) => fail(s"Unexpected record version: $version")
- case Right(transactionalId) =>
- val txnMetadata = TransactionLog.readTxnRecordValue(transactionalId, record.value).get
-
- assertEquals(pidMappings(transactionalId), txnMetadata.producerId)
- assertEquals(producerEpoch, txnMetadata.producerEpoch)
- assertEquals(transactionTimeoutMs, txnMetadata.txnTimeoutMs)
- assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state)
-
- if (txnMetadata.state.equals(TransactionState.EMPTY))
- assertEquals(util.Set.of, txnMetadata.topicPartitions)
- else
- assertEquals(topicPartitions, txnMetadata.topicPartitions)
-
- count = count + 1
- }
- }
-
- assertEquals(pidMappings.size, count)
- }
-
- @Test
- def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = {
- val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, new util.HashSet(), 500, 500, TV_0)
- val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0))
- assertEquals(0, txnLogValueBuffer.getShort)
- }
-
- @Test
- def testSerializeTransactionLogValueToFlexibleVersion(): Unit = {
- val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, new util.HashSet(), 500, 500, TV_2)
- val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2))
- assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort)
- }
-
- @Test
- def testDeserializeHighestSupportedTransactionLogValue(): Unit = {
- val txnPartitions = new TransactionLogValue.PartitionsSchema()
- .setTopic("topic")
- .setPartitionIds(util.List.of(0))
-
- val txnLogValue = new TransactionLogValue()
- .setProducerId(100)
- .setProducerEpoch(50.toShort)
- .setTransactionStatus(TransactionState.COMPLETE_COMMIT.id)
- .setTransactionStartTimestampMs(750L)
- .setTransactionLastUpdateTimestampMs(1000L)
- .setTransactionTimeoutMs(500)
- .setTransactionPartitions(util.List.of(txnPartitions))
-
- val serialized = MessageUtil.toVersionPrefixedByteBuffer(1, txnLogValue)
- val deserialized = TransactionLog.readTxnRecordValue("transactionId", serialized).get
-
- assertEquals(100, deserialized.producerId)
- assertEquals(50, deserialized.producerEpoch)
- assertEquals(TransactionState.COMPLETE_COMMIT, deserialized.state)
- assertEquals(750L, deserialized.txnStartTimestamp)
- assertEquals(1000L, deserialized.txnLastUpdateTimestamp)
- assertEquals(500, deserialized.txnTimeoutMs)
-
- val actualTxnPartitions = deserialized.topicPartitions
- assertEquals(1, actualTxnPartitions.size)
- assertTrue(actualTxnPartitions.contains(new TopicPartition("topic", 0)))
- }
-
- @Test
- def testDeserializeFutureTransactionLogValue(): Unit = {
- // Copy of TransactionLogValue.PartitionsSchema.SCHEMA_1 with a few
- // additional tagged fields.
- val futurePartitionsSchema = new Schema(
- new Field("topic", Type.COMPACT_STRING, ""),
- new Field("partition_ids", new CompactArrayOf(Type.INT32), ""),
- TaggedFieldsSection.of(
- Int.box(100), new Field("partition_foo", Type.STRING, ""),
- Int.box(101), new Field("partition_foo", Type.INT32, "")
- )
- )
-
- // Create TransactionLogValue.PartitionsSchema with tagged fields
- val txnPartitions = new Struct(futurePartitionsSchema)
- txnPartitions.set("topic", "topic")
- txnPartitions.set("partition_ids", Array(Integer.valueOf(1)))
- val txnPartitionsTaggedFields = new util.TreeMap[Integer, Any]()
- txnPartitionsTaggedFields.put(100, "foo")
- txnPartitionsTaggedFields.put(101, 4000)
- txnPartitions.set("_tagged_fields", txnPartitionsTaggedFields)
-
- // Copy of TransactionLogValue.SCHEMA_1 with a few
- // additional tagged fields.
- val futureTransactionLogValueSchema = new Schema(
- new Field("producer_id", Type.INT64, ""),
- new Field("producer_epoch", Type.INT16, ""),
- new Field("transaction_timeout_ms", Type.INT32, ""),
- new Field("transaction_status", Type.INT8, ""),
- new Field("transaction_partitions", CompactArrayOf.nullable(futurePartitionsSchema), ""),
- new Field("transaction_last_update_timestamp_ms", Type.INT64, ""),
- new Field("transaction_start_timestamp_ms", Type.INT64, ""),
- TaggedFieldsSection.of(
- Int.box(100), new Field("txn_foo", Type.STRING, ""),
- Int.box(101), new Field("txn_bar", Type.INT32, "")
- )
- )
-
- // Create TransactionLogValue with tagged fields
- val transactionLogValue = new Struct(futureTransactionLogValueSchema)
- transactionLogValue.set("producer_id", 1000L)
- transactionLogValue.set("producer_epoch", 100.toShort)
- transactionLogValue.set("transaction_timeout_ms", 1000)
- transactionLogValue.set("transaction_status", TransactionState.COMPLETE_COMMIT.id)
- transactionLogValue.set("transaction_partitions", Array(txnPartitions))
- transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L)
- transactionLogValue.set("transaction_start_timestamp_ms", 3000L)
- val txnLogValueTaggedFields = new util.TreeMap[Integer, Any]()
- txnLogValueTaggedFields.put(100, "foo")
- txnLogValueTaggedFields.put(101, 4000)
- transactionLogValue.set("_tagged_fields", txnLogValueTaggedFields)
-
- // Prepare the buffer.
- val buffer = ByteBuffer.allocate(transactionLogValue.sizeOf() + 2)
- buffer.put(0.toByte)
- buffer.put(1.toByte) // Add 1 as version.
- transactionLogValue.writeTo(buffer)
- buffer.flip()
-
- // Read the buffer with the real schema and verify that tagged
- // fields were read but ignored.
- buffer.getShort() // Skip version.
- val value = new TransactionLogValue(new ByteBufferAccessor(buffer), 1.toShort)
- assertEquals(Seq(100, 101), value.unknownTaggedFields().asScala.map(_.tag))
- assertEquals(Seq(100, 101), value.transactionPartitions().get(0).unknownTaggedFields().asScala.map(_.tag))
-
- // Read the buffer with readTxnRecordValue.
- buffer.rewind()
- val txnMetadata = TransactionLog.readTxnRecordValue("transaction-id", buffer).get
- assertEquals(1000L, txnMetadata.producerId)
- assertEquals(100, txnMetadata.producerEpoch)
- assertEquals(1000L, txnMetadata.txnTimeoutMs)
- assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata.state)
- assertEquals(util.Set.of(new TopicPartition("topic", 1)), txnMetadata.topicPartitions)
- assertEquals(2000L, txnMetadata.txnLastUpdateTimestamp)
- assertEquals(3000L, txnMetadata.txnStartTimestamp)
- }
-
- @Test
- def testReadTxnRecordKeyCanReadUnknownMessage(): Unit = {
- val record = new TransactionLogKey()
- val unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, record)
- TransactionLog.readTxnRecordKey(ByteBuffer.wrap(unknownRecord)) match {
- case Left(version) => assertEquals(Short.MaxValue, version)
- case Right(_) => fail("Expected to read unknown message")
- }
- }
-}
diff --git a/transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionLogTest.java b/transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionLogTest.java
new file mode 100644
index 00000000000..3295901a8ba
--- /dev/null
+++ b/transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionLogTest.java
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.coordinator.transaction;
+
+import kafka.coordinator.transaction.TransactionLog;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.compress.Compression;
+import org.apache.kafka.common.protocol.ByteBufferAccessor;
+import org.apache.kafka.common.protocol.MessageUtil;
+import org.apache.kafka.common.protocol.types.CompactArrayOf;
+import org.apache.kafka.common.protocol.types.Field;
+import org.apache.kafka.common.protocol.types.RawTaggedField;
+import org.apache.kafka.common.protocol.types.Schema;
+import org.apache.kafka.common.protocol.types.Struct;
+import org.apache.kafka.common.protocol.types.Type;
+import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.record.SimpleRecord;
+import org.apache.kafka.coordinator.transaction.generated.TransactionLogKey;
+import org.apache.kafka.coordinator.transaction.generated.TransactionLogValue;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.nio.ByteBuffer;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.stream.Stream;
+
+import static java.nio.ByteBuffer.allocate;
+import static java.nio.ByteBuffer.wrap;
+import static org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection;
+import static org.apache.kafka.server.common.TransactionVersion.LATEST_PRODUCTION;
+import static org.apache.kafka.server.common.TransactionVersion.TV_0;
+import static org.apache.kafka.server.common.TransactionVersion.TV_2;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+class TransactionLogTest {
+
+ private final Set topicPartitions = Set.of(
+ new TopicPartition("topic1", 0),
+ new TopicPartition("topic1", 1),
+ new TopicPartition("topic2", 0),
+ new TopicPartition("topic2", 1),
+ new TopicPartition("topic2", 2)
+ );
+
+ private sealed interface TxnKeyResult {
+ record UnknownVersion(short version) implements TxnKeyResult { }
+ record TransactionalId(String id) implements TxnKeyResult { }
+ }
+
+ private static TxnKeyResult readTxnRecordKey(ByteBuffer buf) {
+ var e = TransactionLog.readTxnRecordKey(buf);
+ return e.isLeft()
+ ? new TxnKeyResult.UnknownVersion((Short) e.left().get())
+ : new TxnKeyResult.TransactionalId(e.right().get());
+ }
+
+ private static TransactionMetadata TransactionMetadata(TransactionState state) {
+ return new TransactionMetadata(
+ "transactionalId",
+ 0L,
+ RecordBatch.NO_PRODUCER_ID,
+ RecordBatch.NO_PRODUCER_ID,
+ (short) 0,
+ RecordBatch.NO_PRODUCER_EPOCH,
+ 1000,
+ state,
+ Set.of(),
+ 0,
+ 0,
+ LATEST_PRODUCTION
+ );
+ }
+
+ private static Stream transactionStatesProvider() {
+ return Stream.of(
+ TransactionState.EMPTY,
+ TransactionState.ONGOING,
+ TransactionState.PREPARE_COMMIT,
+ TransactionState.COMPLETE_COMMIT,
+ TransactionState.PREPARE_ABORT,
+ TransactionState.COMPLETE_ABORT
+ );
+ }
+
+ @Test
+ void shouldThrowExceptionWriteInvalidTxn() {
+ var txnMetadata = TransactionMetadata(TransactionState.EMPTY);
+ txnMetadata.addPartitions(topicPartitions);
+
+ var preparedMetadata = txnMetadata.prepareNoTransit();
+ assertThrows(IllegalStateException.class, () -> TransactionLog.valueToBytes(preparedMetadata, TV_2));
+ }
+
+ @ParameterizedTest(name = "{0}")
+ @MethodSource("transactionStatesProvider")
+ void shouldReadWriteMessages(TransactionState state) {
+ var txnMetadata = TransactionMetadata(state);
+ if (state != TransactionState.EMPTY) {
+ txnMetadata.addPartitions(topicPartitions);
+ }
+
+ var record = MemoryRecords.withRecords(Compression.NONE, new SimpleRecord(
+ TransactionLog.keyToBytes(txnMetadata.transactionalId()),
+ TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)
+ )).records().iterator().next();
+ var txnIdResult = assertInstanceOf(TxnKeyResult.TransactionalId.class, readTxnRecordKey(record.key()));
+ var deserialized = TransactionLog.readTxnRecordValue(txnIdResult.id(), record.value()).get();
+
+ assertEquals(txnMetadata.producerId(), deserialized.producerId());
+ assertEquals(txnMetadata.producerEpoch(), deserialized.producerEpoch());
+ assertEquals(txnMetadata.txnTimeoutMs(), deserialized.txnTimeoutMs());
+ assertEquals(txnMetadata.state(), deserialized.state());
+
+ if (txnMetadata.state() == TransactionState.EMPTY) {
+ assertEquals(Set.of(), deserialized.topicPartitions());
+ } else {
+ assertEquals(topicPartitions, deserialized.topicPartitions());
+ }
+ }
+
+ @Test
+ void testSerializeTransactionLogValueToHighestNonFlexibleVersion() {
+ var txnTransitMetadata = new TxnTransitMetadata(1L, 1L, 1L, (short) 1, (short) 1, 1000, TransactionState.COMPLETE_COMMIT, new HashSet<>(), 500L, 500L, TV_0);
+ var txnLogValueBuffer = wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0));
+ assertEquals(TV_0.transactionLogValueVersion(), txnLogValueBuffer.getShort());
+ }
+
+ @Test
+ void testSerializeTransactionLogValueToFlexibleVersion() {
+ var txnTransitMetadata = new TxnTransitMetadata(1L, 1L, 1L, (short) 1, (short) 1, 1000, TransactionState.COMPLETE_COMMIT, new HashSet<>(), 500L, 500L, TV_2);
+ var txnLogValueBuffer = wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2));
+ assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort());
+ }
+
+ @Test
+ void testDeserializeHighestSupportedTransactionLogValue() {
+ var txnPartitions = new TransactionLogValue.PartitionsSchema()
+ .setTopic("topic")
+ .setPartitionIds(List.of(0));
+
+ var txnLogValue = new TransactionLogValue()
+ .setProducerId(100)
+ .setProducerEpoch((short) 50)
+ .setTransactionStatus(TransactionState.COMPLETE_COMMIT.id())
+ .setTransactionStartTimestampMs(750L)
+ .setTransactionLastUpdateTimestampMs(1000L)
+ .setTransactionTimeoutMs(500)
+ .setTransactionPartitions(List.of(txnPartitions));
+
+ var serialized = MessageUtil.toVersionPrefixedByteBuffer((short) 1, txnLogValue);
+ var deserialized = TransactionLog.readTxnRecordValue("transactionId", serialized).get();
+
+ assertEquals(100, deserialized.producerId());
+ assertEquals(50, deserialized.producerEpoch());
+ assertEquals(TransactionState.COMPLETE_COMMIT, deserialized.state());
+ assertEquals(750L, deserialized.txnStartTimestamp());
+ assertEquals(1000L, deserialized.txnLastUpdateTimestamp());
+ assertEquals(500, deserialized.txnTimeoutMs());
+
+ var actualTxnPartitions = deserialized.topicPartitions();
+ assertEquals(1, actualTxnPartitions.size());
+ assertTrue(actualTxnPartitions.contains(new TopicPartition("topic", 0)));
+ }
+
+ @Test
+ void testDeserializeFutureTransactionLogValue() {
+ // Copy of TransactionLogValue.PartitionsSchema.SCHEMA_1 with a few
+ // additional tagged fields.
+ var futurePartitionsSchema = new Schema(
+ new Field("topic", Type.COMPACT_STRING, ""),
+ new Field("partition_ids", new CompactArrayOf(Type.INT32), ""),
+ TaggedFieldsSection.of(
+ 100, new Field("partition_foo", Type.STRING, ""),
+ 101, new Field("partition_foo", Type.INT32, "")
+ )
+ );
+
+ // Create TransactionLogValue.PartitionsSchema with tagged fields
+ var txnPartitions = new Struct(futurePartitionsSchema);
+ txnPartitions.set("topic", "topic");
+ txnPartitions.set("partition_ids", new Integer[]{1});
+ var txnPartitionsTaggedFields = new TreeMap();
+ txnPartitionsTaggedFields.put(100, "foo");
+ txnPartitionsTaggedFields.put(101, 4000);
+ txnPartitions.set("_tagged_fields", txnPartitionsTaggedFields);
+
+ // Copy of TransactionLogValue.SCHEMA_1 with a few
+ // additional tagged fields.
+ var futureTransactionLogValueSchema = new Schema(
+ new Field("producer_id", Type.INT64, ""),
+ new Field("producer_epoch", Type.INT16, ""),
+ new Field("transaction_timeout_ms", Type.INT32, ""),
+ new Field("transaction_status", Type.INT8, ""),
+ new Field("transaction_partitions", CompactArrayOf.nullable(futurePartitionsSchema), ""),
+ new Field("transaction_last_update_timestamp_ms", Type.INT64, ""),
+ new Field("transaction_start_timestamp_ms", Type.INT64, ""),
+ TaggedFieldsSection.of(
+ 100, new Field("txn_foo", Type.STRING, ""),
+ 101, new Field("txn_bar", Type.INT32, "")
+ )
+ );
+
+ // Create TransactionLogValue with tagged fields
+ var transactionLogValue = new Struct(futureTransactionLogValueSchema);
+ transactionLogValue.set("producer_id", 1000L);
+ transactionLogValue.set("producer_epoch", (short) 100);
+ transactionLogValue.set("transaction_timeout_ms", 1000);
+ transactionLogValue.set("transaction_status", TransactionState.COMPLETE_COMMIT.id());
+ transactionLogValue.set("transaction_partitions", new Struct[]{txnPartitions});
+ transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L);
+ transactionLogValue.set("transaction_start_timestamp_ms", 3000L);
+ var txnLogValueTaggedFields = new TreeMap();
+ txnLogValueTaggedFields.put(100, "foo");
+ txnLogValueTaggedFields.put(101, 4000);
+ transactionLogValue.set("_tagged_fields", txnLogValueTaggedFields);
+
+ // Prepare the buffer.
+ var buffer = allocate(Short.BYTES + transactionLogValue.sizeOf());
+ buffer.putShort((short) 1); // Add 1 as version.
+ transactionLogValue.writeTo(buffer);
+ buffer.flip();
+
+ // Read the buffer with the real schema and verify that tagged
+ // fields were read but ignored.
+ buffer.getShort(); // Skip version.
+ var value = new TransactionLogValue(new ByteBufferAccessor(buffer), (short) 1);
+ assertEquals(List.of(100, 101), value.unknownTaggedFields().stream().map(RawTaggedField::tag).toList());
+ assertEquals(List.of(100, 101), value.transactionPartitions().get(0).unknownTaggedFields().stream().map(RawTaggedField::tag).toList());
+
+ // Read the buffer with readTxnRecordValue.
+ buffer.rewind();
+ var txnMetadata = TransactionLog.readTxnRecordValue("transaction-id", buffer);
+
+ assertFalse(txnMetadata.isEmpty(), "Expected transaction metadata but got none");
+
+ var metadata = txnMetadata.get();
+ assertEquals(1000L, metadata.producerId());
+ assertEquals(100, metadata.producerEpoch());
+ assertEquals(1000, metadata.txnTimeoutMs());
+ assertEquals(TransactionState.COMPLETE_COMMIT, metadata.state());
+ assertEquals(Set.of(new TopicPartition("topic", 1)), metadata.topicPartitions());
+ assertEquals(2000L, metadata.txnLastUpdateTimestamp());
+ assertEquals(3000L, metadata.txnStartTimestamp());
+ }
+
+ @Test
+ void testReadTxnRecordKeyCanReadUnknownMessage() {
+ var unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MAX_VALUE, new TransactionLogKey());
+ var result = readTxnRecordKey(wrap(unknownRecord));
+
+ var uv = assertInstanceOf(TxnKeyResult.UnknownVersion.class, result);
+ assertEquals(Short.MAX_VALUE, uv.version());
+ }
+
+ @Test
+ void shouldReturnEmptyWhenForTombstoneRecord() {
+ assertTrue(TransactionLog.readTxnRecordValue("transaction-id", null).isEmpty());
+ }
+}
\ No newline at end of file