KAFKA-17137 Add integration tests for admin client (Transaction and UserScramCredentials related) (#16652)

Reviewers: Chia-Ping Tsai <chia7712@gmail.com>
This commit is contained in:
xijiu 2024-08-07 01:11:55 +08:00 committed by GitHub
parent 4c9795eddf
commit 46f1f0268b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 364 additions and 2 deletions

View File

@ -39,6 +39,7 @@ import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, Produce
import org.apache.kafka.common.acl.{AccessControlEntry, AclBinding, AclBindingFilter, AclOperation, AclPermissionType}
import org.apache.kafka.common.config.{ConfigResource, LogLevelConfig, SslConfigs, TopicConfig}
import org.apache.kafka.common.errors._
import org.apache.kafka.common.internals.Topic
import org.apache.kafka.common.requests.{DeleteRecordsRequest, MetadataResponse}
import org.apache.kafka.common.resource.{PatternType, ResourcePattern, ResourceType}
import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer}
@ -52,7 +53,7 @@ import org.apache.kafka.server.config.{QuotaConfigs, ServerConfigs, ServerLogCon
import org.apache.kafka.storage.internals.log.{CleanerConfig, LogConfig}
import org.apache.kafka.test.TestUtils.DEFAULT_MAX_WAIT_MS
import org.junit.jupiter.api.Assertions._
import org.junit.jupiter.api.{AfterEach, BeforeEach, Disabled, TestInfo}
import org.junit.jupiter.api.{AfterEach, BeforeEach, Disabled, TestInfo, Timeout}
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import org.slf4j.LoggerFactory
@ -94,6 +95,367 @@ class PlaintextAdminIntegrationTest extends BaseAdminIntegrationTest {
super.tearDown()
}
@ParameterizedTest
@ValueSource(strings = Array("zk", "kraft"))
def testDescribeUserScramCredentials(quorum: String): Unit = {
client = createAdminClient
// add a new user
val targetUserName = "tom"
client.alterUserScramCredentials(Collections.singletonList(
new UserScramCredentialUpsertion(targetUserName, new ScramCredentialInfo(ScramMechanism.SCRAM_SHA_256, 4096), "123456")
)).all.get
TestUtils.waitUntilTrue(() => client.describeUserScramCredentials().all().get().size() == 1,
"Add one user scram credential timeout")
val result = client.describeUserScramCredentials().all().get()
result.forEach((userName, scramDescription) => {
assertEquals(targetUserName, userName)
assertEquals(targetUserName, scramDescription.name())
val credentialInfos = scramDescription.credentialInfos()
assertEquals(1, credentialInfos.size())
assertEquals(ScramMechanism.SCRAM_SHA_256, credentialInfos.get(0).mechanism())
assertEquals(4096, credentialInfos.get(0).iterations())
})
// add other users
client.alterUserScramCredentials(util.Arrays.asList(
new UserScramCredentialUpsertion("tom2", new ScramCredentialInfo(ScramMechanism.SCRAM_SHA_256, 4096), "123456"),
new UserScramCredentialUpsertion("tom3", new ScramCredentialInfo(ScramMechanism.SCRAM_SHA_256, 4096), "123456")
)).all().get
TestUtils.waitUntilTrue(() => client.describeUserScramCredentials().all().get().size() == 3,
"Add user scram credential timeout")
// alter user info
client.alterUserScramCredentials(Collections.singletonList(
new UserScramCredentialUpsertion(targetUserName, new ScramCredentialInfo(ScramMechanism.SCRAM_SHA_512, 8192), "123456")
)).all.get
TestUtils.waitUntilTrue(() => {
client.describeUserScramCredentials().all().get().get(targetUserName).credentialInfos().size() == 2
}, "Alter user scram credential timeout")
val userTomResult = client.describeUserScramCredentials().all().get()
assertEquals(3, userTomResult.size())
val userScramCredential = userTomResult.get(targetUserName)
assertEquals(targetUserName, userScramCredential.name())
val credentialInfos = userScramCredential.credentialInfos()
assertEquals(2, credentialInfos.size())
val credentialList = credentialInfos.asScala.sortBy(s => s.mechanism().`type`())
assertEquals(ScramMechanism.SCRAM_SHA_256, credentialList.head.mechanism())
assertEquals(4096, credentialList.head.iterations())
assertEquals(ScramMechanism.SCRAM_SHA_512, credentialList(1).mechanism())
assertEquals(8192, credentialList(1).iterations())
// test describeUserScramCredentials(List<String> users)
val userAndScramMap = client.describeUserScramCredentials(Collections.singletonList("tom2")).all().get()
assertEquals(1, userAndScramMap.size())
val scram = userAndScramMap.get("tom2")
assertNotNull(scram)
val credentialInfo = scram.credentialInfos().get(0)
assertEquals(ScramMechanism.SCRAM_SHA_256, credentialInfo.mechanism())
assertEquals(4096, credentialInfo.iterations())
}
private def createInvalidAdminClient(): Admin = {
val config = createConfig
config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, s"localhost:${TestUtils.IncorrectBrokerPort}")
Admin.create(config)
}
@ParameterizedTest
@Timeout(10)
@ValueSource(strings = Array("zk", "kraft"))
def testDescribeUserScramCredentialsTimeout(quorum: String): Unit = {
client = createInvalidAdminClient()
try {
// test describeUserScramCredentials(List<String> users, DescribeUserScramCredentialsOptions options)
val exception = assertThrows(classOf[ExecutionException], () => {
client.describeUserScramCredentials(Collections.singletonList("tom4"),
new DescribeUserScramCredentialsOptions().timeoutMs(0)).all().get()
})
assertInstanceOf(classOf[TimeoutException], exception.getCause)
} finally client.close(time.Duration.ZERO)
}
private def consumeToExpectedNumber = (expectedNumber: Int) => {
val configs = new util.HashMap[String, Object]()
configs.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, plaintextBootstrapServers(brokers))
configs.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, IsolationLevel.READ_COMMITTED.toString)
val consumer = new KafkaConsumer(configs, new ByteArrayDeserializer, new ByteArrayDeserializer)
try {
consumer.assign(Collections.singleton(topicPartition))
consumer.seekToBeginning(Collections.singleton(topicPartition))
var consumeNum = 0
TestUtils.waitUntilTrue(() => {
val records = consumer.poll(time.Duration.ofMillis(100))
consumeNum += records.count()
consumeNum >= expectedNumber
}, "consumeToExpectedNumber timeout")
} finally consumer.close()
}
@ParameterizedTest
@ValueSource(strings = Array("zk", "kraft"))
def testDescribeProducers(quorum: String): Unit = {
client = createAdminClient
client.createTopics(Collections.singletonList(new NewTopic(topic, 1, 1.toShort))).all().get()
def appendCommonRecords = (records: Int) => {
val producer = new KafkaProducer(Collections.singletonMap(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG,
plaintextBootstrapServers(brokers).asInstanceOf[Object]), new ByteArraySerializer, new ByteArraySerializer)
try {
(0 until records).foreach(i =>
producer.send(new ProducerRecord[Array[Byte], Array[Byte]](
topic, partition, i.toString.getBytes, i.toString.getBytes())))
} finally producer.close()
}
def appendTransactionRecords(transactionId: String, records: Int, commit: Boolean): KafkaProducer[Array[Byte], Array[Byte]] = {
val producer = TestUtils.createTransactionalProducer(transactionId, brokers)
producer.initTransactions()
producer.beginTransaction()
(0 until records).foreach(i =>
producer.send(new ProducerRecord[Array[Byte], Array[Byte]](
topic, partition, i.toString.getBytes, i.toString.getBytes())))
producer.flush()
if (commit) {
producer.commitTransaction()
producer.close()
}
producer
}
def queryProducerDetail() = client
.describeProducers(Collections.singletonList(topicPartition))
.partitionResult(topicPartition).get().activeProducers().asScala
// send common msg
appendCommonRecords(1)
val producerIterator = queryProducerDetail()
assertEquals(1, producerIterator.size)
val producerState = producerIterator.last
assertEquals(0, producerState.producerEpoch())
assertFalse(producerState.coordinatorEpoch().isPresent)
assertFalse(producerState.currentTransactionStartOffset().isPresent)
// send committed transaction msg
appendTransactionRecords("foo", 2, commit = true)
// consume 3 records to ensure transaction finished
consumeToExpectedNumber(3)
val transactionProducerIterator = queryProducerDetail()
assertEquals(2, transactionProducerIterator.size)
val containsCoordinatorEpochIterator = transactionProducerIterator
.filter(producer => producer.coordinatorEpoch().isPresent)
assertEquals(1, containsCoordinatorEpochIterator.size)
val transactionProducerState = containsCoordinatorEpochIterator.last
assertFalse(transactionProducerState.currentTransactionStartOffset().isPresent)
// send ongoing transaction msg
val ongoingProducer = appendTransactionRecords("foo3", 3, commit = false)
try {
val transactionNoneCommitProducerIterator = queryProducerDetail()
assertEquals(3, transactionNoneCommitProducerIterator.size)
val containsOngoingIterator = transactionNoneCommitProducerIterator
.filter(producer => producer.currentTransactionStartOffset().isPresent)
assertEquals(1, containsOngoingIterator.size)
val ongoingTransactionProducerState = containsOngoingIterator.last
// we send (1 common msg) + (2 transaction msg) + (1 transaction marker msg), so transactionStartOffset is 4
assertEquals(4, ongoingTransactionProducerState.currentTransactionStartOffset().getAsLong)
} finally ongoingProducer.close()
}
@ParameterizedTest
@ValueSource(strings = Array("zk", "kraft"))
def testDescribeTransactions(quorum: String): Unit = {
client = createAdminClient
client.createTopics(Collections.singletonList(new NewTopic(topic, 1, 1.toShort))).all().get()
var transactionId = "foo"
def describeTransactions(): TransactionDescription = {
client.describeTransactions(Collections.singleton(transactionId)).description(transactionId).get()
}
def transactionState(): TransactionState = {
describeTransactions().state()
}
def findCoordinatorIdByTransactionId(transactionId: String): Int = {
// calculate the transaction partition id
val transactionPartitionId = Utils.abs(transactionId.hashCode) %
brokers.head.metadataCache.numPartitions(Topic.TRANSACTION_STATE_TOPIC_NAME).get
val transactionTopic = client.describeTopics(Collections.singleton(Topic.TRANSACTION_STATE_TOPIC_NAME))
val partitionList = transactionTopic.allTopicNames().get().get(Topic.TRANSACTION_STATE_TOPIC_NAME).partitions()
partitionList.asScala.filter(tp => tp.partition() == transactionPartitionId).head.leader().id()
}
// normal commit case
val producer = TestUtils.createTransactionalProducer(transactionId, brokers)
try {
// init, the transaction is not begin, so TransactionalIdNotFoundException is expected
val exception = assertThrows(classOf[ExecutionException], () => transactionState())
assertInstanceOf(classOf[TransactionalIdNotFoundException], exception.getCause)
producer.initTransactions()
assertEquals(TransactionState.EMPTY, transactionState())
producer.beginTransaction()
assertEquals(TransactionState.EMPTY, transactionState())
producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, "k1".getBytes, "v1".getBytes()))
producer.flush()
assertEquals(TransactionState.ONGOING, transactionState())
TestUtils.waitUntilTrue(() => describeTransactions().topicPartitions().size() == 1, "Describe transactions timeout")
val transactionResult = describeTransactions()
assertEquals(findCoordinatorIdByTransactionId(transactionId), transactionResult.coordinatorId())
assertEquals(0, transactionResult.producerId())
assertEquals(0, transactionResult.producerEpoch())
assertEquals(Collections.singleton(topicPartition), transactionResult.topicPartitions())
producer.commitTransaction()
val state = transactionState()
// Either PREPARE_COMMIT or COMPLETE_COMMIT is expected
assertTrue(state == TransactionState.PREPARE_COMMIT || state == TransactionState.COMPLETE_COMMIT)
// producer commit transaction, but maybe transaction coordinator has not been submitted mark msg
// so we start up a consumer and consume the expected number of msg, to ensure transaction committed
consumeToExpectedNumber(1)
assertEquals(TransactionState.COMPLETE_COMMIT, transactionState())
} finally producer.close()
// abort case
transactionId = "foo2"
val abortProducer = TestUtils.createTransactionalProducer(transactionId, brokers)
try {
// init, the transaction is not begin, so TransactionalIdNotFoundException is expected
val exception = assertThrows(classOf[ExecutionException], () => transactionState())
assertTrue(exception.getCause.isInstanceOf[TransactionalIdNotFoundException])
abortProducer.initTransactions()
assertEquals(TransactionState.EMPTY, transactionState())
abortProducer.beginTransaction()
assertEquals(TransactionState.EMPTY, transactionState())
val transactionResult = describeTransactions()
assertEquals(findCoordinatorIdByTransactionId(transactionId), transactionResult.coordinatorId())
assertEquals(0, transactionResult.topicPartitions().size())
abortProducer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, "k1".getBytes, "v1".getBytes()))
abortProducer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, "k2".getBytes, "v2".getBytes()))
abortProducer.flush()
val transactionSendMsgResult = describeTransactions()
assertEquals(findCoordinatorIdByTransactionId(transactionId), transactionSendMsgResult.coordinatorId())
assertEquals(Collections.singleton(topicPartition), transactionSendMsgResult.topicPartitions())
assertEquals(topicPartition, transactionSendMsgResult.topicPartitions().asScala.head)
assertEquals(TransactionState.ONGOING, transactionState())
abortProducer.abortTransaction()
val state = transactionState()
assertTrue(state == TransactionState.PREPARE_ABORT || state == TransactionState.COMPLETE_ABORT)
// producer commit transaction, but maybe transaction coordinator has not been submitted mark msg
// so we start up a consumer and consume the expected number of msg, to ensure transaction committed
consumeToExpectedNumber(1)
assertEquals(TransactionState.COMPLETE_ABORT, transactionState())
} finally abortProducer.close()
}
@ParameterizedTest
@Timeout(10)
@ValueSource(strings = Array("zk", "kraft"))
def testDescribeTransactionsTimeout(quorum: String): Unit = {
client = createInvalidAdminClient()
try {
val transactionId = "foo"
val exception = assertThrows(classOf[ExecutionException], () => {
client.describeTransactions(Collections.singleton(transactionId),
new DescribeTransactionsOptions().timeoutMs(0)).description(transactionId).get()
})
assertInstanceOf(classOf[TimeoutException], exception.getCause)
} finally client.close(time.Duration.ZERO)
}
@ParameterizedTest
@Timeout(10)
@ValueSource(strings = Array("zk", "kraft"))
def testAbortTransactionTimeout(quorum: String): Unit = {
client = createInvalidAdminClient()
try {
val exception = assertThrows(classOf[ExecutionException], () => {
client.abortTransaction(
new AbortTransactionSpec(topicPartition, 1, 1, 1),
new AbortTransactionOptions().timeoutMs(0)).all().get()
})
assertInstanceOf(classOf[TimeoutException], exception.getCause)
} finally client.close(time.Duration.ZERO)
}
@ParameterizedTest
@ValueSource(strings = Array("zk", "kraft"))
def testListTransactions(quorum: String): Unit = {
def createTransactionList(): Unit = {
client = createAdminClient
client.createTopics(Collections.singletonList(new NewTopic(topic, 1, 1.toShort))).all().get()
val producer = TestUtils.createTransactionalProducer("foo", brokers)
try {
producer.initTransactions()
producer.beginTransaction()
producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, "k1".getBytes, "v1".getBytes()))
producer.flush()
producer.commitTransaction()
} finally producer.close()
val producer2 = TestUtils.createTransactionalProducer("foo2", brokers)
try {
producer2.initTransactions()
producer2.beginTransaction()
producer2.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, "k1".getBytes, "v1".getBytes()))
producer2.flush()
producer2.abortTransaction()
} finally producer2.close()
val producer3 = TestUtils.createTransactionalProducer("foo3", brokers)
try {
producer3.initTransactions()
producer3.beginTransaction()
producer3.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, "k1".getBytes, "v1".getBytes()))
producer3.flush()
producer3.commitTransaction()
} finally producer3.close()
consumeToExpectedNumber(2)
}
createTransactionList()
assertEquals(3, client.listTransactions().all().get().size())
assertEquals(2, client.listTransactions(new ListTransactionsOptions()
.filterStates(Collections.singletonList(TransactionState.COMPLETE_COMMIT))).all().get().size())
assertEquals(1, client.listTransactions(new ListTransactionsOptions()
.filterStates(Collections.singletonList(TransactionState.COMPLETE_ABORT))).all().get().size())
assertEquals(1, client.listTransactions(new ListTransactionsOptions()
.filterProducerIds(Collections.singletonList(0L))).all().get().size())
// ensure all transaction's txnStartTimestamp >= 500
Thread.sleep(501)
assertEquals(3, client.listTransactions(new ListTransactionsOptions().filterOnDuration(500)).all().get().size())
val producerNew = TestUtils.createTransactionalProducer("foo4", brokers)
try {
producerNew.initTransactions()
producerNew.beginTransaction()
producerNew.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, "k1".getBytes, "v1".getBytes()))
producerNew.flush()
val transactionList = client.listTransactions(new ListTransactionsOptions().filterOnDuration(500)).all().get()
// current transaction start time is now, so transactionList size is still 3
assertEquals(3, transactionList.size())
// transactionList not contains 'foo4'
assertEquals(0, transactionList.asScala.count(t => t.transactionalId().equals("foo4")))
} finally producerNew.close()
}
@ParameterizedTest
@ValueSource(strings = Array("zk", "kraft", "kraft+kip848"))
def testAbortTransaction(quorum: String): Unit = {
@ -599,7 +961,7 @@ class PlaintextAdminIntegrationTest extends BaseAdminIntegrationTest {
// try a newCount which would be a decrease
alterResult = client.createPartitions(Map(topic1 ->
NewPartitions.increaseTo(1)).asJava, option)
var e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get,
() => s"$desc: Expect InvalidPartitionsException when newCount is a decrease")
assertTrue(e.getCause.isInstanceOf[InvalidPartitionsException], desc)