KAFKA-3328: SimpleAclAuthorizer can lose ACLs with frequent add/remov…

…e calls

Changes the SimpleAclAuthorizer to:
- Track and utilize the zookeeper version when updating zookeeper to prevent data loss in the case of stale reads and race conditions
- Update local cache when modifying ACLs
- Add debug logging

Author: Grant Henke <granthenke@gmail.com>
Author: Grant Henke <granthenke@users.noreply.github.com>
Author: Ismael Juma <ismael@juma.me.uk>

Reviewers: Flavio Junqueira, Jun Rao, Ismael Juma, Gwen Shapira

Closes #1006 from granthenke/simple-authorizer-fix
This commit is contained in:
Grant Henke 2016-03-20 00:46:12 -07:00 committed by Gwen Shapira
parent eb823281a5
commit bfac36ad0e
4 changed files with 318 additions and 86 deletions

View File

@ -19,19 +19,20 @@ package kafka.security.auth
import java.util import java.util
import java.util.concurrent.locks.ReentrantReadWriteLock import java.util.concurrent.locks.ReentrantReadWriteLock
import kafka.common.{NotificationHandler, ZkNodeChangeNotificationListener} import kafka.common.{NotificationHandler, ZkNodeChangeNotificationListener}
import org.apache.zookeeper.Watcher.Event.KeeperState
import kafka.network.RequestChannel.Session import kafka.network.RequestChannel.Session
import kafka.security.auth.SimpleAclAuthorizer.VersionedAcls
import kafka.server.KafkaConfig import kafka.server.KafkaConfig
import kafka.utils.CoreUtils.{inReadLock, inWriteLock} import kafka.utils.CoreUtils.{inReadLock, inWriteLock}
import kafka.utils._ import kafka.utils._
import org.I0Itec.zkclient.IZkStateListener import org.I0Itec.zkclient.exception.{ZkNodeExistsException, ZkNoNodeException}
import org.apache.kafka.common.security.JaasUtils import org.apache.kafka.common.security.JaasUtils
import org.apache.kafka.common.security.auth.KafkaPrincipal import org.apache.kafka.common.security.auth.KafkaPrincipal
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import org.apache.log4j.Logger import org.apache.log4j.Logger
import scala.util.Random
object SimpleAclAuthorizer { object SimpleAclAuthorizer {
//optional override zookeeper cluster configuration where acls will be stored, if not specified acls will be stored in //optional override zookeeper cluster configuration where acls will be stored, if not specified acls will be stored in
//same zookeeper where all other kafka broker info is stored. //same zookeeper where all other kafka broker info is stored.
@ -62,6 +63,8 @@ object SimpleAclAuthorizer {
//prefix of all the change notification sequence node. //prefix of all the change notification sequence node.
val AclChangedPrefix = "acl_changes_" val AclChangedPrefix = "acl_changes_"
private case class VersionedAcls(acls: Set[Acl], zkVersion: Int)
} }
class SimpleAclAuthorizer extends Authorizer with Logging { class SimpleAclAuthorizer extends Authorizer with Logging {
@ -71,9 +74,16 @@ class SimpleAclAuthorizer extends Authorizer with Logging {
private var zkUtils: ZkUtils = null private var zkUtils: ZkUtils = null
private var aclChangeListener: ZkNodeChangeNotificationListener = null private var aclChangeListener: ZkNodeChangeNotificationListener = null
private val aclCache = new scala.collection.mutable.HashMap[Resource, Set[Acl]] private val aclCache = new scala.collection.mutable.HashMap[Resource, VersionedAcls]
private val lock = new ReentrantReadWriteLock() private val lock = new ReentrantReadWriteLock()
// The maximum number of times we should try to update the resource acls in zookeeper before failing;
// This should never occur, but is a safeguard just in case.
private val maxUpdateRetries = 10
private val retryBackoffMs = 100
private val retryBackoffJitterMs = 50
/** /**
* Guaranteed to be called before any authorize call is made. * Guaranteed to be called before any authorize call is made.
*/ */
@ -164,67 +174,51 @@ class SimpleAclAuthorizer extends Authorizer with Logging {
override def addAcls(acls: Set[Acl], resource: Resource) { override def addAcls(acls: Set[Acl], resource: Resource) {
if (acls != null && acls.nonEmpty) { if (acls != null && acls.nonEmpty) {
val updatedAcls = getAcls(resource) ++ acls inWriteLock(lock) {
val path = toResourcePath(resource) updateResourceAcls(resource) { currentAcls =>
currentAcls ++ acls
if (zkUtils.pathExists(path)) }
zkUtils.updatePersistentPath(path, Json.encode(Acl.toJsonCompatibleMap(updatedAcls))) }
else
zkUtils.createPersistentPath(path, Json.encode(Acl.toJsonCompatibleMap(updatedAcls)))
updateAclChangedFlag(resource)
} }
} }
override def removeAcls(aclsTobeRemoved: Set[Acl], resource: Resource): Boolean = { override def removeAcls(aclsTobeRemoved: Set[Acl], resource: Resource): Boolean = {
if (zkUtils.pathExists(toResourcePath(resource))) { inWriteLock(lock) {
val existingAcls = getAcls(resource) updateResourceAcls(resource) { currentAcls =>
val filteredAcls = existingAcls.filter((acl: Acl) => !aclsTobeRemoved.contains(acl)) currentAcls -- aclsTobeRemoved
val aclNeedsRemoval = (existingAcls != filteredAcls)
if (aclNeedsRemoval) {
val path: String = toResourcePath(resource)
if (filteredAcls.nonEmpty)
zkUtils.updatePersistentPath(path, Json.encode(Acl.toJsonCompatibleMap(filteredAcls)))
else
zkUtils.deletePath(toResourcePath(resource))
updateAclChangedFlag(resource)
} }
}
aclNeedsRemoval
} else false
} }
override def removeAcls(resource: Resource): Boolean = { override def removeAcls(resource: Resource): Boolean = {
if (zkUtils.pathExists(toResourcePath(resource))) { inWriteLock(lock) {
zkUtils.deletePath(toResourcePath(resource)) val result = zkUtils.deletePath(toResourcePath(resource))
updateCache(resource, VersionedAcls(Set(), 0))
updateAclChangedFlag(resource) updateAclChangedFlag(resource)
true result
} else false }
} }
override def getAcls(resource: Resource): Set[Acl] = { override def getAcls(resource: Resource): Set[Acl] = {
inReadLock(lock) { inReadLock(lock) {
aclCache.get(resource).getOrElse(Set.empty[Acl]) aclCache.get(resource).map(_.acls).getOrElse(Set.empty[Acl])
} }
} }
private def getAclsFromZk(resource: Resource): Set[Acl] = {
val aclJson = zkUtils.readDataMaybeNull(toResourcePath(resource))._1
aclJson.map(Acl.fromJson).getOrElse(Set.empty)
}
override def getAcls(principal: KafkaPrincipal): Map[Resource, Set[Acl]] = { override def getAcls(principal: KafkaPrincipal): Map[Resource, Set[Acl]] = {
aclCache.mapValues { acls => inReadLock(lock) {
acls.filter(_.principal == principal) aclCache.mapValues { versionedAcls =>
}.filter { case (_, acls) => versionedAcls.acls.filter(_.principal == principal)
acls.nonEmpty }.filter { case (_, acls) =>
}.toMap acls.nonEmpty
}.toMap
}
} }
override def getAcls(): Map[Resource, Set[Acl]] = { override def getAcls(): Map[Resource, Set[Acl]] = {
aclCache.toMap inReadLock(lock) {
aclCache.mapValues(_.acls).toMap
}
} }
def close() { def close() {
@ -233,25 +227,17 @@ class SimpleAclAuthorizer extends Authorizer with Logging {
} }
private def loadCache() { private def loadCache() {
var acls = Set.empty[Acl]
val resourceTypes = zkUtils.getChildren(SimpleAclAuthorizer.AclZkPath)
for (rType <- resourceTypes) {
val resourceType = ResourceType.fromString(rType)
val resourceTypePath = SimpleAclAuthorizer.AclZkPath + "/" + resourceType.name
val resourceNames = zkUtils.getChildren(resourceTypePath)
for (resourceName <- resourceNames) {
acls = getAclsFromZk(Resource(resourceType, resourceName.toString))
updateCache(new Resource(resourceType, resourceName), acls)
}
}
}
private def updateCache(resource: Resource, acls: Set[Acl]) {
inWriteLock(lock) { inWriteLock(lock) {
if (acls.nonEmpty) val resourceTypes = zkUtils.getChildren(SimpleAclAuthorizer.AclZkPath)
aclCache.put(resource, acls) for (rType <- resourceTypes) {
else val resourceType = ResourceType.fromString(rType)
aclCache.remove(resource) val resourceTypePath = SimpleAclAuthorizer.AclZkPath + "/" + resourceType.name
val resourceNames = zkUtils.getChildren(resourceTypePath)
for (resourceName <- resourceNames) {
val versionedAcls = getAclsFromZk(Resource(resourceType, resourceName.toString))
updateCache(new Resource(resourceType, resourceName), versionedAcls)
}
}
} }
} }
@ -264,16 +250,117 @@ class SimpleAclAuthorizer extends Authorizer with Logging {
authorizerLogger.debug(s"Principal = $principal is $permissionType Operation = $operation from host = $host on resource = $resource") authorizerLogger.debug(s"Principal = $principal is $permissionType Operation = $operation from host = $host on resource = $resource")
} }
/**
* Safely updates the resources ACLs by ensuring reads and writes respect the expected zookeeper version.
* Continues to retry until it succesfully updates zookeeper.
*
* Returns a boolean indicating if the content of the ACLs was actually changed.
*
* @param resource the resource to change ACLs for
* @param getNewAcls function to transform existing acls to new ACLs
* @return boolean indicating if a change was made
*/
private def updateResourceAcls(resource: Resource)(getNewAcls: Set[Acl] => Set[Acl]): Boolean = {
val path = toResourcePath(resource)
var currentVersionedAcls =
if (aclCache.contains(resource))
getAclsFromCache(resource)
else
getAclsFromZk(resource)
var newVersionedAcls: VersionedAcls = null
var writeComplete = false
var retries = 0
while (!writeComplete && retries <= maxUpdateRetries) {
val newAcls = getNewAcls(currentVersionedAcls.acls)
val data = Json.encode(Acl.toJsonCompatibleMap(newAcls))
val (updateSucceeded, updateVersion) =
if (!newAcls.isEmpty) {
updatePath(path, data, currentVersionedAcls.zkVersion)
} else {
trace(s"Deleting path for $resource because it had no ACLs remaining")
(zkUtils.conditionalDeletePath(path, currentVersionedAcls.zkVersion), 0)
}
if (!updateSucceeded) {
trace(s"Failed to update ACLs for $resource. Used version ${currentVersionedAcls.zkVersion}. Reading data and retrying update.")
Thread.sleep(backoffTime)
currentVersionedAcls = getAclsFromZk(resource);
retries += 1
} else {
newVersionedAcls = VersionedAcls(newAcls, updateVersion)
writeComplete = updateSucceeded
}
}
if(!writeComplete)
throw new IllegalStateException(s"Failed to update ACLs for $resource after trying a maximum of $maxUpdateRetries times")
if (newVersionedAcls.acls != currentVersionedAcls.acls) {
debug(s"Updated ACLs for $resource to ${newVersionedAcls.acls} with version ${newVersionedAcls.zkVersion}")
updateCache(resource, newVersionedAcls)
updateAclChangedFlag(resource)
true
} else {
debug(s"Updated ACLs for $resource, no change was made")
updateCache(resource, newVersionedAcls) // Even if no change, update the version
false
}
}
/**
* Updates a zookeeper path with an expected version. If the topic does not exist, it will create it.
* Returns if the update was successful and the new version.
*/
private def updatePath(path: String, data: String, expectedVersion: Int): (Boolean, Int) = {
try {
zkUtils.conditionalUpdatePersistentPathIfExists(path, data, expectedVersion)
} catch {
case e: ZkNoNodeException =>
try {
debug(s"Node $path does not exist, attempting to create it.")
zkUtils.createPersistentPath(path, data)
(true, 0)
} catch {
case e: ZkNodeExistsException =>
debug(s"Failed to create node for $path because it already exists.")
(false, 0)
}
}
}
private def getAclsFromCache(resource: Resource): VersionedAcls = {
aclCache.getOrElse(resource, throw new IllegalArgumentException(s"ACLs do not exist in the cache for resource $resource"))
}
private def getAclsFromZk(resource: Resource): VersionedAcls = {
val (aclJson, stat) = zkUtils.readDataMaybeNull(toResourcePath(resource))
VersionedAcls(aclJson.map(Acl.fromJson).getOrElse(Set()), stat.getVersion)
}
private def updateCache(resource: Resource, versionedAcls: VersionedAcls) {
if (versionedAcls.acls.nonEmpty) {
aclCache.put(resource, versionedAcls)
} else {
aclCache.remove(resource)
}
}
private def updateAclChangedFlag(resource: Resource) { private def updateAclChangedFlag(resource: Resource) {
zkUtils.createSequentialPersistentPath(SimpleAclAuthorizer.AclChangedZkPath + "/" + SimpleAclAuthorizer.AclChangedPrefix, resource.toString) zkUtils.createSequentialPersistentPath(SimpleAclAuthorizer.AclChangedZkPath + "/" + SimpleAclAuthorizer.AclChangedPrefix, resource.toString)
} }
object AclChangedNotificationHandler extends NotificationHandler { private def backoffTime = {
retryBackoffMs + Random.nextInt(retryBackoffJitterMs)
}
object AclChangedNotificationHandler extends NotificationHandler {
override def processNotification(notificationMessage: String) { override def processNotification(notificationMessage: String) {
val resource: Resource = Resource.fromString(notificationMessage) val resource: Resource = Resource.fromString(notificationMessage)
val acls = getAclsFromZk(resource) inWriteLock(lock) {
updateCache(resource, acls) val versionedAcls = getAclsFromZk(resource)
updateCache(resource, versionedAcls)
}
} }
} }
} }

View File

@ -52,12 +52,12 @@ object ZkUtils {
val IsrChangeNotificationPath = "/isr_change_notification" val IsrChangeNotificationPath = "/isr_change_notification"
val EntityConfigPath = "/config" val EntityConfigPath = "/config"
val EntityConfigChangesPath = "/config/changes" val EntityConfigChangesPath = "/config/changes"
def apply(zkUrl: String, sessionTimeout: Int, connectionTimeout: Int, isZkSecurityEnabled: Boolean): ZkUtils = { def apply(zkUrl: String, sessionTimeout: Int, connectionTimeout: Int, isZkSecurityEnabled: Boolean): ZkUtils = {
val (zkClient, zkConnection) = createZkClientAndConnection(zkUrl, sessionTimeout, connectionTimeout) val (zkClient, zkConnection) = createZkClientAndConnection(zkUrl, sessionTimeout, connectionTimeout)
new ZkUtils(zkClient, zkConnection, isZkSecurityEnabled) new ZkUtils(zkClient, zkConnection, isZkSecurityEnabled)
} }
/* /*
* Used in tests * Used in tests
*/ */
@ -75,7 +75,7 @@ object ZkUtils {
val zkClient = new ZkClient(zkConnection, connectionTimeout, ZKStringSerializer) val zkClient = new ZkClient(zkConnection, connectionTimeout, ZKStringSerializer)
(zkClient, zkConnection) (zkClient, zkConnection)
} }
def DefaultAcls(isSecure: Boolean): java.util.List[ACL] = if (isSecure) { def DefaultAcls(isSecure: Boolean): java.util.List[ACL] = if (isSecure) {
val list = new java.util.ArrayList[ACL] val list = new java.util.ArrayList[ACL]
list.addAll(ZooDefs.Ids.CREATOR_ALL_ACL) list.addAll(ZooDefs.Ids.CREATOR_ALL_ACL)
@ -84,7 +84,7 @@ object ZkUtils {
} else { } else {
ZooDefs.Ids.OPEN_ACL_UNSAFE ZooDefs.Ids.OPEN_ACL_UNSAFE
} }
def maybeDeletePath(zkUrl: String, dir: String) { def maybeDeletePath(zkUrl: String, dir: String) {
try { try {
val zk = createZkClient(zkUrl, 30*1000, 30*1000) val zk = createZkClient(zkUrl, 30*1000, 30*1000)
@ -94,7 +94,7 @@ object ZkUtils {
case _: Throwable => // swallow case _: Throwable => // swallow
} }
} }
/* /*
* Get calls that only depend on static paths * Get calls that only depend on static paths
*/ */
@ -111,7 +111,7 @@ object ZkUtils {
def getTopicPartitionLeaderAndIsrPath(topic: String, partitionId: Int): String = def getTopicPartitionLeaderAndIsrPath(topic: String, partitionId: Int): String =
getTopicPartitionPath(topic, partitionId) + "/" + "state" getTopicPartitionPath(topic, partitionId) + "/" + "state"
def getEntityConfigRootPath(entityType: String): String = def getEntityConfigRootPath(entityType: String): String =
ZkUtils.EntityConfigPath + "/" + entityType ZkUtils.EntityConfigPath + "/" + entityType
@ -122,7 +122,7 @@ object ZkUtils {
DeleteTopicsPath + "/" + topic DeleteTopicsPath + "/" + topic
} }
class ZkUtils(val zkClient: ZkClient, class ZkUtils(val zkClient: ZkClient,
val zkConnection: ZkConnection, val zkConnection: ZkConnection,
val isSecure: Boolean) extends Logging { val isSecure: Boolean) extends Logging {
// These are persistent ZK paths that should exist on kafka broker startup. // These are persistent ZK paths that should exist on kafka broker startup.
@ -146,7 +146,7 @@ class ZkUtils(val zkClient: ZkClient,
IsrChangeNotificationPath) IsrChangeNotificationPath)
val DefaultAcls: java.util.List[ACL] = ZkUtils.DefaultAcls(isSecure) val DefaultAcls: java.util.List[ACL] = ZkUtils.DefaultAcls(isSecure)
def getController(): Int = { def getController(): Int = {
readDataMaybeNull(ControllerPath)._1 match { readDataMaybeNull(ControllerPath)._1 match {
case Some(controller) => KafkaController.parseControllerId(controller) case Some(controller) => KafkaController.parseControllerId(controller)
@ -512,6 +512,19 @@ class ZkUtils(val zkClient: ZkClient,
} }
} }
/**
* Conditional delete the persistent path data, return true if it succeeds,
* otherwise (the current version is not the expected version)
*/
def conditionalDeletePath(path: String, expectedVersion: Int): Boolean = {
try {
zkClient.delete(path, expectedVersion)
true
} catch {
case e: KeeperException.BadVersionException => false
}
}
def deletePathRecursive(path: String) { def deletePathRecursive(path: String) {
try { try {
zkClient.deleteRecursive(path) zkClient.deleteRecursive(path)
@ -847,7 +860,7 @@ class ZkUtils(val zkClient: ZkClient,
} }
} }
} }
def close() { def close() {
if(zkClient != null) { if(zkClient != null) {
zkClient.close() zkClient.close()
@ -941,7 +954,7 @@ object ZkPath {
* znode is created and the create call returns OK. If * znode is created and the create call returns OK. If
* the call receives a node exists event, then it checks * the call receives a node exists event, then it checks
* if the session matches. If it does, then it returns OK, * if the session matches. If it does, then it returns OK,
* and otherwise it fails the operation. * and otherwise it fails the operation.
*/ */
class ZKCheckedEphemeral(path: String, class ZKCheckedEphemeral(path: String,
@ -952,7 +965,7 @@ class ZKCheckedEphemeral(path: String,
private val getDataCallback = new GetDataCallback private val getDataCallback = new GetDataCallback
val latch: CountDownLatch = new CountDownLatch(1) val latch: CountDownLatch = new CountDownLatch(1)
var result: Code = Code.OK var result: Code = Code.OK
private class CreateCallback extends StringCallback { private class CreateCallback extends StringCallback {
def processResult(rc: Int, def processResult(rc: Int,
path: String, path: String,
@ -1009,7 +1022,7 @@ class ZKCheckedEphemeral(path: String,
} }
} }
} }
private def createEphemeral() { private def createEphemeral() {
zkHandle.create(path, zkHandle.create(path,
ZKStringSerializer.serialize(data), ZKStringSerializer.serialize(data),
@ -1018,7 +1031,7 @@ class ZKCheckedEphemeral(path: String,
createCallback, createCallback,
null) null)
} }
private def createRecursive(prefix: String, suffix: String) { private def createRecursive(prefix: String, suffix: String) {
debug("Path: %s, Prefix: %s, Suffix: %s".format(path, prefix, suffix)) debug("Path: %s, Prefix: %s, Suffix: %s".format(path, prefix, suffix))
if(suffix.isEmpty()) { if(suffix.isEmpty()) {

View File

@ -17,7 +17,7 @@
package kafka.security.auth package kafka.security.auth
import java.net.InetAddress import java.net.InetAddress
import java.util.UUID import java.util.{UUID}
import kafka.network.RequestChannel.Session import kafka.network.RequestChannel.Session
import kafka.security.auth.Acl.WildCardHost import kafka.security.auth.Acl.WildCardHost
@ -31,6 +31,7 @@ import org.junit.{After, Before, Test}
class SimpleAclAuthorizerTest extends ZooKeeperTestHarness { class SimpleAclAuthorizerTest extends ZooKeeperTestHarness {
val simpleAclAuthorizer = new SimpleAclAuthorizer val simpleAclAuthorizer = new SimpleAclAuthorizer
val simpleAclAuthorizer2 = new SimpleAclAuthorizer
val testPrincipal = Acl.WildCardPrincipal val testPrincipal = Acl.WildCardPrincipal
val testHostName = InetAddress.getByName("192.168.0.1") val testHostName = InetAddress.getByName("192.168.0.1")
val session = new Session(testPrincipal, testHostName) val session = new Session(testPrincipal, testHostName)
@ -48,12 +49,14 @@ class SimpleAclAuthorizerTest extends ZooKeeperTestHarness {
config = KafkaConfig.fromProps(props) config = KafkaConfig.fromProps(props)
simpleAclAuthorizer.configure(config.originals) simpleAclAuthorizer.configure(config.originals)
simpleAclAuthorizer2.configure(config.originals)
resource = new Resource(Topic, UUID.randomUUID().toString) resource = new Resource(Topic, UUID.randomUUID().toString)
} }
@After @After
override def tearDown(): Unit = { override def tearDown(): Unit = {
simpleAclAuthorizer.close() simpleAclAuthorizer.close()
simpleAclAuthorizer2.close()
} }
@Test @Test
@ -254,6 +257,87 @@ class SimpleAclAuthorizerTest extends ZooKeeperTestHarness {
assertEquals(acls1, authorizer.getAcls(resource1)) assertEquals(acls1, authorizer.getAcls(resource1))
} }
@Test
def testLocalConcurrentModificationOfResourceAcls() {
val commonResource = new Resource(Topic, "test")
val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username)
val acl1 = new Acl(user1, Allow, WildCardHost, Read)
val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob")
val acl2 = new Acl(user2, Deny, WildCardHost, Read)
simpleAclAuthorizer.addAcls(Set(acl1), commonResource)
simpleAclAuthorizer.addAcls(Set(acl2), commonResource)
TestUtils.waitAndVerifyAcls(Set(acl1, acl2), simpleAclAuthorizer, commonResource)
}
@Test
def testDistributedConcurrentModificationOfResourceAcls() {
val commonResource = new Resource(Topic, "test")
val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username)
val acl1 = new Acl(user1, Allow, WildCardHost, Read)
val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob")
val acl2 = new Acl(user2, Deny, WildCardHost, Read)
// Add on each instance
simpleAclAuthorizer.addAcls(Set(acl1), commonResource)
simpleAclAuthorizer2.addAcls(Set(acl2), commonResource)
TestUtils.waitAndVerifyAcls(Set(acl1, acl2), simpleAclAuthorizer, commonResource)
TestUtils.waitAndVerifyAcls(Set(acl1, acl2), simpleAclAuthorizer2, commonResource)
val user3 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "joe")
val acl3 = new Acl(user3, Deny, WildCardHost, Read)
// Add on one instance and delete on another
simpleAclAuthorizer.addAcls(Set(acl3), commonResource)
val deleted = simpleAclAuthorizer2.removeAcls(Set(acl3), commonResource)
assertTrue("The authorizer should see a value that needs to be deleted", deleted)
TestUtils.waitAndVerifyAcls(Set(acl1, acl2), simpleAclAuthorizer, commonResource)
TestUtils.waitAndVerifyAcls(Set(acl1, acl2), simpleAclAuthorizer2, commonResource)
}
@Test
def testHighConcurrencyModificationOfResourceAcls() {
val commonResource = new Resource(Topic, "test")
val acls = (0 to 100).map { i =>
val useri = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, i.toString)
new Acl(useri, Allow, WildCardHost, Read)
}
// Alternate authorizer, Remove all acls that end in 0
val concurrentFuctions = acls.map { acl =>
() => {
val aclId = acl.principal.getName.toInt
if (aclId % 2 == 0) {
simpleAclAuthorizer.addAcls(Set(acl), commonResource)
} else {
simpleAclAuthorizer2.addAcls(Set(acl), commonResource)
}
if (aclId % 10 == 0) {
simpleAclAuthorizer2.removeAcls(Set(acl), commonResource)
}
}
}
val expectedAcls = acls.filter { acl =>
val aclId = acl.principal.getName.toInt
aclId % 10 != 0
}.toSet
TestUtils.assertConcurrent("Should support many concurrent calls", concurrentFuctions, 15000)
TestUtils.waitAndVerifyAcls(expectedAcls, simpleAclAuthorizer, commonResource)
TestUtils.waitAndVerifyAcls(expectedAcls, simpleAclAuthorizer2, commonResource)
}
private def changeAclAndVerify(originalAcls: Set[Acl], addedAcls: Set[Acl], removedAcls: Set[Acl], resource: Resource = resource): Set[Acl] = { private def changeAclAndVerify(originalAcls: Set[Acl], addedAcls: Set[Acl], removedAcls: Set[Acl], resource: Resource = resource): Set[Acl] = {
var acls = originalAcls var acls = originalAcls

View File

@ -21,8 +21,9 @@ import java.io._
import java.nio._ import java.nio._
import java.nio.file.Files import java.nio.file.Files
import java.nio.channels._ import java.nio.channels._
import java.util.Random import java.util
import java.util.Properties import java.util.concurrent.{Callable, TimeUnit, Executors}
import java.util.{Collections, Random, Properties}
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import javax.net.ssl.X509TrustManager import javax.net.ssl.X509TrustManager
import charset.Charset import charset.Charset
@ -54,6 +55,7 @@ import org.apache.kafka.common.serialization.{ByteArraySerializer, Serializer}
import scala.collection.Map import scala.collection.Map
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
/** /**
* Utility functions to help with testing * Utility functions to help with testing
@ -131,6 +133,7 @@ object TestUtils extends Logging {
/** /**
* Create a kafka server instance with appropriate test settings * Create a kafka server instance with appropriate test settings
* USING THIS IS A SIGN YOU ARE NOT WRITING A REAL UNIT TEST * USING THIS IS A SIGN YOU ARE NOT WRITING A REAL UNIT TEST
*
* @param config The configuration of the server * @param config The configuration of the server
*/ */
def createServer(config: KafkaConfig, time: Time = SystemTime): KafkaServer = { def createServer(config: KafkaConfig, time: Time = SystemTime): KafkaServer = {
@ -141,7 +144,7 @@ object TestUtils extends Logging {
/** /**
* Create a test config for the provided parameters. * Create a test config for the provided parameters.
* *
* Note that if `interBrokerSecurityProtocol` is defined, the listener for the `SecurityProtocol` will be enabled. * Note that if `interBrokerSecurityProtocol` is defined, the listener for the `SecurityProtocol` will be enabled.
*/ */
def createBrokerConfigs(numConfigs: Int, def createBrokerConfigs(numConfigs: Int,
@ -281,6 +284,7 @@ object TestUtils extends Logging {
/** /**
* Wrap the message in a message set * Wrap the message in a message set
*
* @param payload The bytes of the message * @param payload The bytes of the message
*/ */
def singleMessageSet(payload: Array[Byte], def singleMessageSet(payload: Array[Byte],
@ -291,6 +295,7 @@ object TestUtils extends Logging {
/** /**
* Generate an array of random bytes * Generate an array of random bytes
*
* @param numBytes The size of the array * @param numBytes The size of the array
*/ */
def randomBytes(numBytes: Int): Array[Byte] = { def randomBytes(numBytes: Int): Array[Byte] = {
@ -301,6 +306,7 @@ object TestUtils extends Logging {
/** /**
* Generate a random string of letters and digits of the given length * Generate a random string of letters and digits of the given length
*
* @param len The length of the string * @param len The length of the string
* @return The random string * @return The random string
*/ */
@ -679,6 +685,7 @@ object TestUtils extends Logging {
* If neither oldLeaderOpt nor newLeaderOpt is defined, wait until the leader of a partition is elected. * If neither oldLeaderOpt nor newLeaderOpt is defined, wait until the leader of a partition is elected.
* If oldLeaderOpt is defined, it waits until the new leader is different from the old leader. * If oldLeaderOpt is defined, it waits until the new leader is different from the old leader.
* If newLeaderOpt is defined, it waits until the new leader becomes the expected new leader. * If newLeaderOpt is defined, it waits until the new leader becomes the expected new leader.
*
* @return The new leader or assertion failure if timeout is reached. * @return The new leader or assertion failure if timeout is reached.
*/ */
def waitUntilLeaderIsElectedOrChanged(zkUtils: ZkUtils, topic: String, partition: Int, timeoutMs: Long = 5000L, def waitUntilLeaderIsElectedOrChanged(zkUtils: ZkUtils, topic: String, partition: Int, timeoutMs: Long = 5000L,
@ -786,6 +793,7 @@ object TestUtils extends Logging {
/** /**
* Wait until a valid leader is propagated to the metadata cache in each broker. * Wait until a valid leader is propagated to the metadata cache in each broker.
* It assumes that the leader propagated to each broker is the same. * It assumes that the leader propagated to each broker is the same.
*
* @param servers The list of servers that the metadata should reach to * @param servers The list of servers that the metadata should reach to
* @param topic The topic name * @param topic The topic name
* @param partition The partition Id * @param partition The partition Id
@ -812,7 +820,7 @@ object TestUtils extends Logging {
} }
def waitUntilLeaderIsKnown(servers: Seq[KafkaServer], topic: String, partition: Int, timeout: Long = 5000L): Unit = { def waitUntilLeaderIsKnown(servers: Seq[KafkaServer], topic: String, partition: Int, timeout: Long = 5000L): Unit = {
TestUtils.waitUntilTrue(() => TestUtils.waitUntilTrue(() =>
servers.exists { server => servers.exists { server =>
server.replicaManager.getPartition(topic, partition).exists(_.leaderReplicaIfLocal().isDefined) server.replicaManager.getPartition(topic, partition).exists(_.leaderReplicaIfLocal().isDefined)
}, },
@ -968,12 +976,11 @@ object TestUtils extends Logging {
/** /**
* Consume all messages (or a specific number of messages) * Consume all messages (or a specific number of messages)
*
* @param topicMessageStreams the Topic Message Streams * @param topicMessageStreams the Topic Message Streams
* @param nMessagesPerThread an optional field to specify the exact number of messages to be returned. * @param nMessagesPerThread an optional field to specify the exact number of messages to be returned.
* ConsumerTimeoutException will be thrown if there are no messages to be consumed. * ConsumerTimeoutException will be thrown if there are no messages to be consumed.
* If not specified, then all available messages will be consumed, and no exception is thrown. * If not specified, then all available messages will be consumed, and no exception is thrown.
*
*
* @return the list of messages consumed. * @return the list of messages consumed.
*/ */
def getMessages(topicMessageStreams: Map[String, List[KafkaStream[String, String]]], def getMessages(topicMessageStreams: Map[String, List[KafkaStream[String, String]]],
@ -1033,6 +1040,7 @@ object TestUtils extends Logging {
/** /**
* Translate the given buffer into a string * Translate the given buffer into a string
*
* @param buffer The buffer to translate * @param buffer The buffer to translate
* @param encoding The encoding to use in translating bytes to characters * @param encoding The encoding to use in translating bytes to characters
*/ */
@ -1075,6 +1083,46 @@ object TestUtils extends Logging {
s"expected acls $expected but got ${authorizer.getAcls(resource)}", waitTime = 10000) s"expected acls $expected but got ${authorizer.getAcls(resource)}", waitTime = 10000)
} }
/**
* To use this you pass in a sequence of functions that are your arrange/act/assert test on the SUT.
* They all run at the same time in the assertConcurrent method; the chances of triggering a multithreading code error,
* and thereby failing some assertion are greatly increased.
*/
def assertConcurrent(message: String, functions: Seq[() => Any], timeoutMs: Int) {
def failWithTimeout() {
fail(s"$message. Timed out, the concurrent functions took more than $timeoutMs milliseconds")
}
val numThreads = functions.size
val threadPool = Executors.newFixedThreadPool(numThreads)
val exceptions = ArrayBuffer[Throwable]()
try {
val runnables = functions.map { function =>
new Callable[Unit] {
override def call(): Unit = function()
}
}.asJava
val futures = threadPool.invokeAll(runnables, timeoutMs, TimeUnit.MILLISECONDS).asScala
futures.foreach { future =>
if (future.isCancelled)
failWithTimeout()
else
try future.get()
catch { case e: Exception =>
exceptions += e
}
}
} catch {
case ie: InterruptedException => failWithTimeout()
case e => exceptions += e
} finally {
threadPool.shutdownNow()
}
assertTrue(s"$message failed with exception(s) $exceptions", exceptions.isEmpty)
}
} }
class IntEncoder(props: VerifiableProperties = null) extends Encoder[Int] { class IntEncoder(props: VerifiableProperties = null) extends Encoder[Int] {