From 749c2d91d52036f789cacc7ee0e04d0bcada6813 Mon Sep 17 00:00:00 2001 From: Hong-Yi Chen Date: Mon, 15 Sep 2025 11:25:54 +0800 Subject: [PATCH] KAFKA-19609 Move TransactionLogTest to transaction-coordinator module (#20460) This PR migrates the `TransactionLogTest` from Scala to Java for better consistency with the rest of the test suite and to simplify future maintenance. Reviewers: Chia-Ping Tsai --- ...import-control-transaction-coordinator.xml | 1 + .../transaction/TransactionLog.scala | 4 +- .../transaction/TransactionLogTest.scala | 246 --------------- .../transaction/TransactionLogTest.java | 283 ++++++++++++++++++ 4 files changed, 286 insertions(+), 248 deletions(-) delete mode 100644 core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala create mode 100644 transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionLogTest.java diff --git a/checkstyle/import-control-transaction-coordinator.xml b/checkstyle/import-control-transaction-coordinator.xml index a48100a9acc..810c127c95c 100644 --- a/checkstyle/import-control-transaction-coordinator.xml +++ b/checkstyle/import-control-transaction-coordinator.xml @@ -40,6 +40,7 @@ + diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala index a3e9eacb66f..f024e88aa8e 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala @@ -52,7 +52,7 @@ object TransactionLog { * * @return key bytes */ - private[transaction] def keyToBytes(transactionalId: String): Array[Byte] = { + def keyToBytes(transactionalId: String): Array[Byte] = { MessageUtil.toCoordinatorTypePrefixedBytes(new TransactionLogKey().setTransactionalId(transactionalId)) } @@ -61,7 +61,7 @@ object TransactionLog { * * @return value payload bytes */ - private[transaction] def valueToBytes(txnMetadata: TxnTransitMetadata, + def valueToBytes(txnMetadata: TxnTransitMetadata, transactionVersionLevel: TransactionVersion): Array[Byte] = { if (txnMetadata.txnState == TransactionState.EMPTY && !txnMetadata.topicPartitions.isEmpty) throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata") diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala deleted file mode 100644 index 46116be6d7c..00000000000 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala +++ /dev/null @@ -1,246 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package kafka.coordinator.transaction - -import org.apache.kafka.common.TopicPartition -import org.apache.kafka.common.compress.Compression -import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} -import org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection -import org.apache.kafka.common.protocol.types.{CompactArrayOf, Field, Schema, Struct, Type} -import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, SimpleRecord} -import org.apache.kafka.coordinator.transaction.{TransactionMetadata, TransactionState, TxnTransitMetadata} -import org.apache.kafka.coordinator.transaction.generated.{TransactionLogKey, TransactionLogValue} -import org.apache.kafka.server.common.TransactionVersion.{TV_0, TV_2} -import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue, fail} -import org.junit.jupiter.api.Test - -import java.nio.ByteBuffer -import java.util -import scala.collection.Seq -import scala.jdk.CollectionConverters._ - -class TransactionLogTest { - - val producerEpoch: Short = 0 - val transactionTimeoutMs: Int = 1000 - - val topicPartitions = util.Set.of(new TopicPartition("topic1", 0), - new TopicPartition("topic1", 1), - new TopicPartition("topic2", 0), - new TopicPartition("topic2", 1), - new TopicPartition("topic2", 2)) - - @Test - def shouldThrowExceptionWriteInvalidTxn(): Unit = { - val transactionalId = "transactionalId" - val producerId = 23423L - - val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, TransactionState.EMPTY, util.Set.of, 0, 0, TV_0) - txnMetadata.addPartitions(topicPartitions) - - assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2)) - } - - @Test - def shouldReadWriteMessages(): Unit = { - val pidMappings = Map[String, Long]("zero" -> 0L, - "one" -> 1L, - "two" -> 2L, - "three" -> 3L, - "four" -> 4L, - "five" -> 5L) - - val transactionStates = Map[Long, TransactionState](0L -> TransactionState.EMPTY, - 1L -> TransactionState.ONGOING, - 2L -> TransactionState.PREPARE_COMMIT, - 3L -> TransactionState.COMPLETE_COMMIT, - 4L -> TransactionState.PREPARE_ABORT, - 5L -> TransactionState.COMPLETE_ABORT) - - // generate transaction log messages - val txnRecords = pidMappings.map { case (transactionalId, producerId) => - val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_ID, producerEpoch, - RecordBatch.NO_PRODUCER_EPOCH, transactionTimeoutMs, transactionStates(producerId), util.Set.of, 0, 0, TV_0) - - if (!txnMetadata.state.equals(TransactionState.EMPTY)) - txnMetadata.addPartitions(topicPartitions) - - val keyBytes = TransactionLog.keyToBytes(transactionalId) - val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2) - - new SimpleRecord(keyBytes, valueBytes) - }.toSeq - - val records = MemoryRecords.withRecords(0, Compression.NONE, txnRecords: _*) - - var count = 0 - for (record <- records.records.asScala) { - TransactionLog.readTxnRecordKey(record.key) match { - case Left(version) => fail(s"Unexpected record version: $version") - case Right(transactionalId) => - val txnMetadata = TransactionLog.readTxnRecordValue(transactionalId, record.value).get - - assertEquals(pidMappings(transactionalId), txnMetadata.producerId) - assertEquals(producerEpoch, txnMetadata.producerEpoch) - assertEquals(transactionTimeoutMs, txnMetadata.txnTimeoutMs) - assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state) - - if (txnMetadata.state.equals(TransactionState.EMPTY)) - assertEquals(util.Set.of, txnMetadata.topicPartitions) - else - assertEquals(topicPartitions, txnMetadata.topicPartitions) - - count = count + 1 - } - } - - assertEquals(pidMappings.size, count) - } - - @Test - def testSerializeTransactionLogValueToHighestNonFlexibleVersion(): Unit = { - val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, new util.HashSet(), 500, 500, TV_0) - val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0)) - assertEquals(0, txnLogValueBuffer.getShort) - } - - @Test - def testSerializeTransactionLogValueToFlexibleVersion(): Unit = { - val txnTransitMetadata = new TxnTransitMetadata(1, 1, 1, 1, 1, 1000, TransactionState.COMPLETE_COMMIT, new util.HashSet(), 500, 500, TV_2) - val txnLogValueBuffer = ByteBuffer.wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2)) - assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort) - } - - @Test - def testDeserializeHighestSupportedTransactionLogValue(): Unit = { - val txnPartitions = new TransactionLogValue.PartitionsSchema() - .setTopic("topic") - .setPartitionIds(util.List.of(0)) - - val txnLogValue = new TransactionLogValue() - .setProducerId(100) - .setProducerEpoch(50.toShort) - .setTransactionStatus(TransactionState.COMPLETE_COMMIT.id) - .setTransactionStartTimestampMs(750L) - .setTransactionLastUpdateTimestampMs(1000L) - .setTransactionTimeoutMs(500) - .setTransactionPartitions(util.List.of(txnPartitions)) - - val serialized = MessageUtil.toVersionPrefixedByteBuffer(1, txnLogValue) - val deserialized = TransactionLog.readTxnRecordValue("transactionId", serialized).get - - assertEquals(100, deserialized.producerId) - assertEquals(50, deserialized.producerEpoch) - assertEquals(TransactionState.COMPLETE_COMMIT, deserialized.state) - assertEquals(750L, deserialized.txnStartTimestamp) - assertEquals(1000L, deserialized.txnLastUpdateTimestamp) - assertEquals(500, deserialized.txnTimeoutMs) - - val actualTxnPartitions = deserialized.topicPartitions - assertEquals(1, actualTxnPartitions.size) - assertTrue(actualTxnPartitions.contains(new TopicPartition("topic", 0))) - } - - @Test - def testDeserializeFutureTransactionLogValue(): Unit = { - // Copy of TransactionLogValue.PartitionsSchema.SCHEMA_1 with a few - // additional tagged fields. - val futurePartitionsSchema = new Schema( - new Field("topic", Type.COMPACT_STRING, ""), - new Field("partition_ids", new CompactArrayOf(Type.INT32), ""), - TaggedFieldsSection.of( - Int.box(100), new Field("partition_foo", Type.STRING, ""), - Int.box(101), new Field("partition_foo", Type.INT32, "") - ) - ) - - // Create TransactionLogValue.PartitionsSchema with tagged fields - val txnPartitions = new Struct(futurePartitionsSchema) - txnPartitions.set("topic", "topic") - txnPartitions.set("partition_ids", Array(Integer.valueOf(1))) - val txnPartitionsTaggedFields = new util.TreeMap[Integer, Any]() - txnPartitionsTaggedFields.put(100, "foo") - txnPartitionsTaggedFields.put(101, 4000) - txnPartitions.set("_tagged_fields", txnPartitionsTaggedFields) - - // Copy of TransactionLogValue.SCHEMA_1 with a few - // additional tagged fields. - val futureTransactionLogValueSchema = new Schema( - new Field("producer_id", Type.INT64, ""), - new Field("producer_epoch", Type.INT16, ""), - new Field("transaction_timeout_ms", Type.INT32, ""), - new Field("transaction_status", Type.INT8, ""), - new Field("transaction_partitions", CompactArrayOf.nullable(futurePartitionsSchema), ""), - new Field("transaction_last_update_timestamp_ms", Type.INT64, ""), - new Field("transaction_start_timestamp_ms", Type.INT64, ""), - TaggedFieldsSection.of( - Int.box(100), new Field("txn_foo", Type.STRING, ""), - Int.box(101), new Field("txn_bar", Type.INT32, "") - ) - ) - - // Create TransactionLogValue with tagged fields - val transactionLogValue = new Struct(futureTransactionLogValueSchema) - transactionLogValue.set("producer_id", 1000L) - transactionLogValue.set("producer_epoch", 100.toShort) - transactionLogValue.set("transaction_timeout_ms", 1000) - transactionLogValue.set("transaction_status", TransactionState.COMPLETE_COMMIT.id) - transactionLogValue.set("transaction_partitions", Array(txnPartitions)) - transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L) - transactionLogValue.set("transaction_start_timestamp_ms", 3000L) - val txnLogValueTaggedFields = new util.TreeMap[Integer, Any]() - txnLogValueTaggedFields.put(100, "foo") - txnLogValueTaggedFields.put(101, 4000) - transactionLogValue.set("_tagged_fields", txnLogValueTaggedFields) - - // Prepare the buffer. - val buffer = ByteBuffer.allocate(transactionLogValue.sizeOf() + 2) - buffer.put(0.toByte) - buffer.put(1.toByte) // Add 1 as version. - transactionLogValue.writeTo(buffer) - buffer.flip() - - // Read the buffer with the real schema and verify that tagged - // fields were read but ignored. - buffer.getShort() // Skip version. - val value = new TransactionLogValue(new ByteBufferAccessor(buffer), 1.toShort) - assertEquals(Seq(100, 101), value.unknownTaggedFields().asScala.map(_.tag)) - assertEquals(Seq(100, 101), value.transactionPartitions().get(0).unknownTaggedFields().asScala.map(_.tag)) - - // Read the buffer with readTxnRecordValue. - buffer.rewind() - val txnMetadata = TransactionLog.readTxnRecordValue("transaction-id", buffer).get - assertEquals(1000L, txnMetadata.producerId) - assertEquals(100, txnMetadata.producerEpoch) - assertEquals(1000L, txnMetadata.txnTimeoutMs) - assertEquals(TransactionState.COMPLETE_COMMIT, txnMetadata.state) - assertEquals(util.Set.of(new TopicPartition("topic", 1)), txnMetadata.topicPartitions) - assertEquals(2000L, txnMetadata.txnLastUpdateTimestamp) - assertEquals(3000L, txnMetadata.txnStartTimestamp) - } - - @Test - def testReadTxnRecordKeyCanReadUnknownMessage(): Unit = { - val record = new TransactionLogKey() - val unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, record) - TransactionLog.readTxnRecordKey(ByteBuffer.wrap(unknownRecord)) match { - case Left(version) => assertEquals(Short.MaxValue, version) - case Right(_) => fail("Expected to read unknown message") - } - } -} diff --git a/transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionLogTest.java b/transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionLogTest.java new file mode 100644 index 00000000000..3295901a8ba --- /dev/null +++ b/transaction-coordinator/src/test/java/org/apache/kafka/coordinator/transaction/TransactionLogTest.java @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.coordinator.transaction; + +import kafka.coordinator.transaction.TransactionLog; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.compress.Compression; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.common.protocol.types.CompactArrayOf; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.RawTaggedField; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.coordinator.transaction.generated.TransactionLogKey; +import org.apache.kafka.coordinator.transaction.generated.TransactionLogValue; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.nio.ByteBuffer; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.TreeMap; +import java.util.stream.Stream; + +import static java.nio.ByteBuffer.allocate; +import static java.nio.ByteBuffer.wrap; +import static org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection; +import static org.apache.kafka.server.common.TransactionVersion.LATEST_PRODUCTION; +import static org.apache.kafka.server.common.TransactionVersion.TV_0; +import static org.apache.kafka.server.common.TransactionVersion.TV_2; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +class TransactionLogTest { + + private final Set topicPartitions = Set.of( + new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1), + new TopicPartition("topic2", 0), + new TopicPartition("topic2", 1), + new TopicPartition("topic2", 2) + ); + + private sealed interface TxnKeyResult { + record UnknownVersion(short version) implements TxnKeyResult { } + record TransactionalId(String id) implements TxnKeyResult { } + } + + private static TxnKeyResult readTxnRecordKey(ByteBuffer buf) { + var e = TransactionLog.readTxnRecordKey(buf); + return e.isLeft() + ? new TxnKeyResult.UnknownVersion((Short) e.left().get()) + : new TxnKeyResult.TransactionalId(e.right().get()); + } + + private static TransactionMetadata TransactionMetadata(TransactionState state) { + return new TransactionMetadata( + "transactionalId", + 0L, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_ID, + (short) 0, + RecordBatch.NO_PRODUCER_EPOCH, + 1000, + state, + Set.of(), + 0, + 0, + LATEST_PRODUCTION + ); + } + + private static Stream transactionStatesProvider() { + return Stream.of( + TransactionState.EMPTY, + TransactionState.ONGOING, + TransactionState.PREPARE_COMMIT, + TransactionState.COMPLETE_COMMIT, + TransactionState.PREPARE_ABORT, + TransactionState.COMPLETE_ABORT + ); + } + + @Test + void shouldThrowExceptionWriteInvalidTxn() { + var txnMetadata = TransactionMetadata(TransactionState.EMPTY); + txnMetadata.addPartitions(topicPartitions); + + var preparedMetadata = txnMetadata.prepareNoTransit(); + assertThrows(IllegalStateException.class, () -> TransactionLog.valueToBytes(preparedMetadata, TV_2)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("transactionStatesProvider") + void shouldReadWriteMessages(TransactionState state) { + var txnMetadata = TransactionMetadata(state); + if (state != TransactionState.EMPTY) { + txnMetadata.addPartitions(topicPartitions); + } + + var record = MemoryRecords.withRecords(Compression.NONE, new SimpleRecord( + TransactionLog.keyToBytes(txnMetadata.transactionalId()), + TransactionLog.valueToBytes(txnMetadata.prepareNoTransit(), TV_2) + )).records().iterator().next(); + var txnIdResult = assertInstanceOf(TxnKeyResult.TransactionalId.class, readTxnRecordKey(record.key())); + var deserialized = TransactionLog.readTxnRecordValue(txnIdResult.id(), record.value()).get(); + + assertEquals(txnMetadata.producerId(), deserialized.producerId()); + assertEquals(txnMetadata.producerEpoch(), deserialized.producerEpoch()); + assertEquals(txnMetadata.txnTimeoutMs(), deserialized.txnTimeoutMs()); + assertEquals(txnMetadata.state(), deserialized.state()); + + if (txnMetadata.state() == TransactionState.EMPTY) { + assertEquals(Set.of(), deserialized.topicPartitions()); + } else { + assertEquals(topicPartitions, deserialized.topicPartitions()); + } + } + + @Test + void testSerializeTransactionLogValueToHighestNonFlexibleVersion() { + var txnTransitMetadata = new TxnTransitMetadata(1L, 1L, 1L, (short) 1, (short) 1, 1000, TransactionState.COMPLETE_COMMIT, new HashSet<>(), 500L, 500L, TV_0); + var txnLogValueBuffer = wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_0)); + assertEquals(TV_0.transactionLogValueVersion(), txnLogValueBuffer.getShort()); + } + + @Test + void testSerializeTransactionLogValueToFlexibleVersion() { + var txnTransitMetadata = new TxnTransitMetadata(1L, 1L, 1L, (short) 1, (short) 1, 1000, TransactionState.COMPLETE_COMMIT, new HashSet<>(), 500L, 500L, TV_2); + var txnLogValueBuffer = wrap(TransactionLog.valueToBytes(txnTransitMetadata, TV_2)); + assertEquals(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, txnLogValueBuffer.getShort()); + } + + @Test + void testDeserializeHighestSupportedTransactionLogValue() { + var txnPartitions = new TransactionLogValue.PartitionsSchema() + .setTopic("topic") + .setPartitionIds(List.of(0)); + + var txnLogValue = new TransactionLogValue() + .setProducerId(100) + .setProducerEpoch((short) 50) + .setTransactionStatus(TransactionState.COMPLETE_COMMIT.id()) + .setTransactionStartTimestampMs(750L) + .setTransactionLastUpdateTimestampMs(1000L) + .setTransactionTimeoutMs(500) + .setTransactionPartitions(List.of(txnPartitions)); + + var serialized = MessageUtil.toVersionPrefixedByteBuffer((short) 1, txnLogValue); + var deserialized = TransactionLog.readTxnRecordValue("transactionId", serialized).get(); + + assertEquals(100, deserialized.producerId()); + assertEquals(50, deserialized.producerEpoch()); + assertEquals(TransactionState.COMPLETE_COMMIT, deserialized.state()); + assertEquals(750L, deserialized.txnStartTimestamp()); + assertEquals(1000L, deserialized.txnLastUpdateTimestamp()); + assertEquals(500, deserialized.txnTimeoutMs()); + + var actualTxnPartitions = deserialized.topicPartitions(); + assertEquals(1, actualTxnPartitions.size()); + assertTrue(actualTxnPartitions.contains(new TopicPartition("topic", 0))); + } + + @Test + void testDeserializeFutureTransactionLogValue() { + // Copy of TransactionLogValue.PartitionsSchema.SCHEMA_1 with a few + // additional tagged fields. + var futurePartitionsSchema = new Schema( + new Field("topic", Type.COMPACT_STRING, ""), + new Field("partition_ids", new CompactArrayOf(Type.INT32), ""), + TaggedFieldsSection.of( + 100, new Field("partition_foo", Type.STRING, ""), + 101, new Field("partition_foo", Type.INT32, "") + ) + ); + + // Create TransactionLogValue.PartitionsSchema with tagged fields + var txnPartitions = new Struct(futurePartitionsSchema); + txnPartitions.set("topic", "topic"); + txnPartitions.set("partition_ids", new Integer[]{1}); + var txnPartitionsTaggedFields = new TreeMap(); + txnPartitionsTaggedFields.put(100, "foo"); + txnPartitionsTaggedFields.put(101, 4000); + txnPartitions.set("_tagged_fields", txnPartitionsTaggedFields); + + // Copy of TransactionLogValue.SCHEMA_1 with a few + // additional tagged fields. + var futureTransactionLogValueSchema = new Schema( + new Field("producer_id", Type.INT64, ""), + new Field("producer_epoch", Type.INT16, ""), + new Field("transaction_timeout_ms", Type.INT32, ""), + new Field("transaction_status", Type.INT8, ""), + new Field("transaction_partitions", CompactArrayOf.nullable(futurePartitionsSchema), ""), + new Field("transaction_last_update_timestamp_ms", Type.INT64, ""), + new Field("transaction_start_timestamp_ms", Type.INT64, ""), + TaggedFieldsSection.of( + 100, new Field("txn_foo", Type.STRING, ""), + 101, new Field("txn_bar", Type.INT32, "") + ) + ); + + // Create TransactionLogValue with tagged fields + var transactionLogValue = new Struct(futureTransactionLogValueSchema); + transactionLogValue.set("producer_id", 1000L); + transactionLogValue.set("producer_epoch", (short) 100); + transactionLogValue.set("transaction_timeout_ms", 1000); + transactionLogValue.set("transaction_status", TransactionState.COMPLETE_COMMIT.id()); + transactionLogValue.set("transaction_partitions", new Struct[]{txnPartitions}); + transactionLogValue.set("transaction_last_update_timestamp_ms", 2000L); + transactionLogValue.set("transaction_start_timestamp_ms", 3000L); + var txnLogValueTaggedFields = new TreeMap(); + txnLogValueTaggedFields.put(100, "foo"); + txnLogValueTaggedFields.put(101, 4000); + transactionLogValue.set("_tagged_fields", txnLogValueTaggedFields); + + // Prepare the buffer. + var buffer = allocate(Short.BYTES + transactionLogValue.sizeOf()); + buffer.putShort((short) 1); // Add 1 as version. + transactionLogValue.writeTo(buffer); + buffer.flip(); + + // Read the buffer with the real schema and verify that tagged + // fields were read but ignored. + buffer.getShort(); // Skip version. + var value = new TransactionLogValue(new ByteBufferAccessor(buffer), (short) 1); + assertEquals(List.of(100, 101), value.unknownTaggedFields().stream().map(RawTaggedField::tag).toList()); + assertEquals(List.of(100, 101), value.transactionPartitions().get(0).unknownTaggedFields().stream().map(RawTaggedField::tag).toList()); + + // Read the buffer with readTxnRecordValue. + buffer.rewind(); + var txnMetadata = TransactionLog.readTxnRecordValue("transaction-id", buffer); + + assertFalse(txnMetadata.isEmpty(), "Expected transaction metadata but got none"); + + var metadata = txnMetadata.get(); + assertEquals(1000L, metadata.producerId()); + assertEquals(100, metadata.producerEpoch()); + assertEquals(1000, metadata.txnTimeoutMs()); + assertEquals(TransactionState.COMPLETE_COMMIT, metadata.state()); + assertEquals(Set.of(new TopicPartition("topic", 1)), metadata.topicPartitions()); + assertEquals(2000L, metadata.txnLastUpdateTimestamp()); + assertEquals(3000L, metadata.txnStartTimestamp()); + } + + @Test + void testReadTxnRecordKeyCanReadUnknownMessage() { + var unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MAX_VALUE, new TransactionLogKey()); + var result = readTxnRecordKey(wrap(unknownRecord)); + + var uv = assertInstanceOf(TxnKeyResult.UnknownVersion.class, result); + assertEquals(Short.MAX_VALUE, uv.version()); + } + + @Test + void shouldReturnEmptyWhenForTombstoneRecord() { + assertTrue(TransactionLog.readTxnRecordValue("transaction-id", null).isEmpty()); + } +} \ No newline at end of file