Repository: kafka Updated Branches: refs/heads/trunk eb823281a -> bfac36ad0
KAFKA-3328: SimpleAclAuthorizer can lose ACLs with frequent add/remove calls Changes the SimpleAclAuthorizer to: - Track and utilize the zookeeper version when updating zookeeper to prevent data loss in the case of stale reads and race conditions - Update local cache when modifying ACLs - Add debug logging Author: Grant Henke <[email protected]> Author: Grant Henke <[email protected]> Author: Ismael Juma <[email protected]> Reviewers: Flavio Junqueira, Jun Rao, Ismael Juma, Gwen Shapira Closes #1006 from granthenke/simple-authorizer-fix Project: http://git-wip-us.apache.org/repos/asf/kafka/repo Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/bfac36ad Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/bfac36ad Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/bfac36ad Branch: refs/heads/trunk Commit: bfac36ad0e378b5f39e3889e40a75c5c1fc48fa7 Parents: eb82328 Author: Grant Henke <[email protected]> Authored: Sun Mar 20 00:46:12 2016 -0700 Committer: Gwen Shapira <[email protected]> Committed: Sun Mar 20 00:46:12 2016 -0700 ---------------------------------------------------------------------- .../security/auth/SimpleAclAuthorizer.scala | 219 +++++++++++++------ core/src/main/scala/kafka/utils/ZkUtils.scala | 39 ++-- .../security/auth/SimpleAclAuthorizerTest.scala | 86 +++++++- .../test/scala/unit/kafka/utils/TestUtils.scala | 60 ++++- 4 files changed, 318 insertions(+), 86 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/kafka/blob/bfac36ad/core/src/main/scala/kafka/security/auth/SimpleAclAuthorizer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/security/auth/SimpleAclAuthorizer.scala b/core/src/main/scala/kafka/security/auth/SimpleAclAuthorizer.scala index 77e23f8..1a06af2 100644 --- a/core/src/main/scala/kafka/security/auth/SimpleAclAuthorizer.scala +++ 
b/core/src/main/scala/kafka/security/auth/SimpleAclAuthorizer.scala @@ -19,19 +19,20 @@ package kafka.security.auth import java.util import java.util.concurrent.locks.ReentrantReadWriteLock import kafka.common.{NotificationHandler, ZkNodeChangeNotificationListener} -import org.apache.zookeeper.Watcher.Event.KeeperState - import kafka.network.RequestChannel.Session +import kafka.security.auth.SimpleAclAuthorizer.VersionedAcls import kafka.server.KafkaConfig import kafka.utils.CoreUtils.{inReadLock, inWriteLock} import kafka.utils._ -import org.I0Itec.zkclient.IZkStateListener +import org.I0Itec.zkclient.exception.{ZkNodeExistsException, ZkNoNodeException} import org.apache.kafka.common.security.JaasUtils import org.apache.kafka.common.security.auth.KafkaPrincipal import scala.collection.JavaConverters._ import org.apache.log4j.Logger +import scala.util.Random + object SimpleAclAuthorizer { //optional override zookeeper cluster configuration where acls will be stored, if not specified acls will be stored in //same zookeeper where all other kafka broker info is stored. @@ -62,6 +63,8 @@ object SimpleAclAuthorizer { //prefix of all the change notification sequence node. val AclChangedPrefix = "acl_changes_" + + private case class VersionedAcls(acls: Set[Acl], zkVersion: Int) } class SimpleAclAuthorizer extends Authorizer with Logging { @@ -71,9 +74,16 @@ class SimpleAclAuthorizer extends Authorizer with Logging { private var zkUtils: ZkUtils = null private var aclChangeListener: ZkNodeChangeNotificationListener = null - private val aclCache = new scala.collection.mutable.HashMap[Resource, Set[Acl]] + private val aclCache = new scala.collection.mutable.HashMap[Resource, VersionedAcls] private val lock = new ReentrantReadWriteLock() + // The maximum number of times we should try to update the resource acls in zookeeper before failing; + // This should never occur, but is a safeguard just in case. 
+ private val maxUpdateRetries = 10 + + private val retryBackoffMs = 100 + private val retryBackoffJitterMs = 50 + /** * Guaranteed to be called before any authorize call is made. */ @@ -164,67 +174,51 @@ class SimpleAclAuthorizer extends Authorizer with Logging { override def addAcls(acls: Set[Acl], resource: Resource) { if (acls != null && acls.nonEmpty) { - val updatedAcls = getAcls(resource) ++ acls - val path = toResourcePath(resource) - - if (zkUtils.pathExists(path)) - zkUtils.updatePersistentPath(path, Json.encode(Acl.toJsonCompatibleMap(updatedAcls))) - else - zkUtils.createPersistentPath(path, Json.encode(Acl.toJsonCompatibleMap(updatedAcls))) - - updateAclChangedFlag(resource) + inWriteLock(lock) { + updateResourceAcls(resource) { currentAcls => + currentAcls ++ acls + } + } } } override def removeAcls(aclsTobeRemoved: Set[Acl], resource: Resource): Boolean = { - if (zkUtils.pathExists(toResourcePath(resource))) { - val existingAcls = getAcls(resource) - val filteredAcls = existingAcls.filter((acl: Acl) => !aclsTobeRemoved.contains(acl)) - - val aclNeedsRemoval = (existingAcls != filteredAcls) - if (aclNeedsRemoval) { - val path: String = toResourcePath(resource) - if (filteredAcls.nonEmpty) - zkUtils.updatePersistentPath(path, Json.encode(Acl.toJsonCompatibleMap(filteredAcls))) - else - zkUtils.deletePath(toResourcePath(resource)) - - updateAclChangedFlag(resource) + inWriteLock(lock) { + updateResourceAcls(resource) { currentAcls => + currentAcls -- aclsTobeRemoved } - - aclNeedsRemoval - } else false + } } override def removeAcls(resource: Resource): Boolean = { - if (zkUtils.pathExists(toResourcePath(resource))) { - zkUtils.deletePath(toResourcePath(resource)) + inWriteLock(lock) { + val result = zkUtils.deletePath(toResourcePath(resource)) + updateCache(resource, VersionedAcls(Set(), 0)) updateAclChangedFlag(resource) - true - } else false + result + } } override def getAcls(resource: Resource): Set[Acl] = { inReadLock(lock) { - 
aclCache.get(resource).getOrElse(Set.empty[Acl]) + aclCache.get(resource).map(_.acls).getOrElse(Set.empty[Acl]) } } - private def getAclsFromZk(resource: Resource): Set[Acl] = { - val aclJson = zkUtils.readDataMaybeNull(toResourcePath(resource))._1 - aclJson.map(Acl.fromJson).getOrElse(Set.empty) - } - override def getAcls(principal: KafkaPrincipal): Map[Resource, Set[Acl]] = { - aclCache.mapValues { acls => - acls.filter(_.principal == principal) - }.filter { case (_, acls) => - acls.nonEmpty - }.toMap + inReadLock(lock) { + aclCache.mapValues { versionedAcls => + versionedAcls.acls.filter(_.principal == principal) + }.filter { case (_, acls) => + acls.nonEmpty + }.toMap + } } override def getAcls(): Map[Resource, Set[Acl]] = { - aclCache.toMap + inReadLock(lock) { + aclCache.mapValues(_.acls).toMap + } } def close() { @@ -233,25 +227,17 @@ class SimpleAclAuthorizer extends Authorizer with Logging { } private def loadCache() { - var acls = Set.empty[Acl] - val resourceTypes = zkUtils.getChildren(SimpleAclAuthorizer.AclZkPath) - for (rType <- resourceTypes) { - val resourceType = ResourceType.fromString(rType) - val resourceTypePath = SimpleAclAuthorizer.AclZkPath + "/" + resourceType.name - val resourceNames = zkUtils.getChildren(resourceTypePath) - for (resourceName <- resourceNames) { - acls = getAclsFromZk(Resource(resourceType, resourceName.toString)) - updateCache(new Resource(resourceType, resourceName), acls) - } - } - } - - private def updateCache(resource: Resource, acls: Set[Acl]) { inWriteLock(lock) { - if (acls.nonEmpty) - aclCache.put(resource, acls) - else - aclCache.remove(resource) + val resourceTypes = zkUtils.getChildren(SimpleAclAuthorizer.AclZkPath) + for (rType <- resourceTypes) { + val resourceType = ResourceType.fromString(rType) + val resourceTypePath = SimpleAclAuthorizer.AclZkPath + "/" + resourceType.name + val resourceNames = zkUtils.getChildren(resourceTypePath) + for (resourceName <- resourceNames) { + val versionedAcls = 
getAclsFromZk(Resource(resourceType, resourceName.toString)) + updateCache(new Resource(resourceType, resourceName), versionedAcls) + } + } } } @@ -264,16 +250,117 @@ class SimpleAclAuthorizer extends Authorizer with Logging { authorizerLogger.debug(s"Principal = $principal is $permissionType Operation = $operation from host = $host on resource = $resource") } + /** + * Safely updates the resources ACLs by ensuring reads and writes respect the expected zookeeper version. + * Continues to retry until it succesfully updates zookeeper. + * + * Returns a boolean indicating if the content of the ACLs was actually changed. + * + * @param resource the resource to change ACLs for + * @param getNewAcls function to transform existing acls to new ACLs + * @return boolean indicating if a change was made + */ + private def updateResourceAcls(resource: Resource)(getNewAcls: Set[Acl] => Set[Acl]): Boolean = { + val path = toResourcePath(resource) + + var currentVersionedAcls = + if (aclCache.contains(resource)) + getAclsFromCache(resource) + else + getAclsFromZk(resource) + var newVersionedAcls: VersionedAcls = null + var writeComplete = false + var retries = 0 + while (!writeComplete && retries <= maxUpdateRetries) { + val newAcls = getNewAcls(currentVersionedAcls.acls) + val data = Json.encode(Acl.toJsonCompatibleMap(newAcls)) + val (updateSucceeded, updateVersion) = + if (!newAcls.isEmpty) { + updatePath(path, data, currentVersionedAcls.zkVersion) + } else { + trace(s"Deleting path for $resource because it had no ACLs remaining") + (zkUtils.conditionalDeletePath(path, currentVersionedAcls.zkVersion), 0) + } + + if (!updateSucceeded) { + trace(s"Failed to update ACLs for $resource. Used version ${currentVersionedAcls.zkVersion}. 
Reading data and retrying update.") + Thread.sleep(backoffTime) + currentVersionedAcls = getAclsFromZk(resource); + retries += 1 + } else { + newVersionedAcls = VersionedAcls(newAcls, updateVersion) + writeComplete = updateSucceeded + } + } + + if(!writeComplete) + throw new IllegalStateException(s"Failed to update ACLs for $resource after trying a maximum of $maxUpdateRetries times") + + if (newVersionedAcls.acls != currentVersionedAcls.acls) { + debug(s"Updated ACLs for $resource to ${newVersionedAcls.acls} with version ${newVersionedAcls.zkVersion}") + updateCache(resource, newVersionedAcls) + updateAclChangedFlag(resource) + true + } else { + debug(s"Updated ACLs for $resource, no change was made") + updateCache(resource, newVersionedAcls) // Even if no change, update the version + false + } + } + + /** + * Updates a zookeeper path with an expected version. If the topic does not exist, it will create it. + * Returns if the update was successful and the new version. + */ + private def updatePath(path: String, data: String, expectedVersion: Int): (Boolean, Int) = { + try { + zkUtils.conditionalUpdatePersistentPathIfExists(path, data, expectedVersion) + } catch { + case e: ZkNoNodeException => + try { + debug(s"Node $path does not exist, attempting to create it.") + zkUtils.createPersistentPath(path, data) + (true, 0) + } catch { + case e: ZkNodeExistsException => + debug(s"Failed to create node for $path because it already exists.") + (false, 0) + } + } + } + + private def getAclsFromCache(resource: Resource): VersionedAcls = { + aclCache.getOrElse(resource, throw new IllegalArgumentException(s"ACLs do not exist in the cache for resource $resource")) + } + + private def getAclsFromZk(resource: Resource): VersionedAcls = { + val (aclJson, stat) = zkUtils.readDataMaybeNull(toResourcePath(resource)) + VersionedAcls(aclJson.map(Acl.fromJson).getOrElse(Set()), stat.getVersion) + } + + private def updateCache(resource: Resource, versionedAcls: VersionedAcls) { + if 
(versionedAcls.acls.nonEmpty) { + aclCache.put(resource, versionedAcls) + } else { + aclCache.remove(resource) + } + } + private def updateAclChangedFlag(resource: Resource) { zkUtils.createSequentialPersistentPath(SimpleAclAuthorizer.AclChangedZkPath + "/" + SimpleAclAuthorizer.AclChangedPrefix, resource.toString) } - object AclChangedNotificationHandler extends NotificationHandler { + private def backoffTime = { + retryBackoffMs + Random.nextInt(retryBackoffJitterMs) + } + object AclChangedNotificationHandler extends NotificationHandler { override def processNotification(notificationMessage: String) { val resource: Resource = Resource.fromString(notificationMessage) - val acls = getAclsFromZk(resource) - updateCache(resource, acls) + inWriteLock(lock) { + val versionedAcls = getAclsFromZk(resource) + updateCache(resource, versionedAcls) + } } } } http://git-wip-us.apache.org/repos/asf/kafka/blob/bfac36ad/core/src/main/scala/kafka/utils/ZkUtils.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/utils/ZkUtils.scala b/core/src/main/scala/kafka/utils/ZkUtils.scala index 99c8196..49d3cfa 100644 --- a/core/src/main/scala/kafka/utils/ZkUtils.scala +++ b/core/src/main/scala/kafka/utils/ZkUtils.scala @@ -52,12 +52,12 @@ object ZkUtils { val IsrChangeNotificationPath = "/isr_change_notification" val EntityConfigPath = "/config" val EntityConfigChangesPath = "/config/changes" - + def apply(zkUrl: String, sessionTimeout: Int, connectionTimeout: Int, isZkSecurityEnabled: Boolean): ZkUtils = { val (zkClient, zkConnection) = createZkClientAndConnection(zkUrl, sessionTimeout, connectionTimeout) new ZkUtils(zkClient, zkConnection, isZkSecurityEnabled) } - + /* * Used in tests */ @@ -75,7 +75,7 @@ object ZkUtils { val zkClient = new ZkClient(zkConnection, connectionTimeout, ZKStringSerializer) (zkClient, zkConnection) } - + def DefaultAcls(isSecure: Boolean): java.util.List[ACL] = if (isSecure) { val list = new 
java.util.ArrayList[ACL] list.addAll(ZooDefs.Ids.CREATOR_ALL_ACL) @@ -84,7 +84,7 @@ object ZkUtils { } else { ZooDefs.Ids.OPEN_ACL_UNSAFE } - + def maybeDeletePath(zkUrl: String, dir: String) { try { val zk = createZkClient(zkUrl, 30*1000, 30*1000) @@ -94,7 +94,7 @@ object ZkUtils { case _: Throwable => // swallow } } - + /* * Get calls that only depend on static paths */ @@ -111,7 +111,7 @@ object ZkUtils { def getTopicPartitionLeaderAndIsrPath(topic: String, partitionId: Int): String = getTopicPartitionPath(topic, partitionId) + "/" + "state" - + def getEntityConfigRootPath(entityType: String): String = ZkUtils.EntityConfigPath + "/" + entityType @@ -122,7 +122,7 @@ object ZkUtils { DeleteTopicsPath + "/" + topic } -class ZkUtils(val zkClient: ZkClient, +class ZkUtils(val zkClient: ZkClient, val zkConnection: ZkConnection, val isSecure: Boolean) extends Logging { // These are persistent ZK paths that should exist on kafka broker startup. @@ -146,7 +146,7 @@ class ZkUtils(val zkClient: ZkClient, IsrChangeNotificationPath) val DefaultAcls: java.util.List[ACL] = ZkUtils.DefaultAcls(isSecure) - + def getController(): Int = { readDataMaybeNull(ControllerPath)._1 match { case Some(controller) => KafkaController.parseControllerId(controller) @@ -512,6 +512,19 @@ class ZkUtils(val zkClient: ZkClient, } } + /** + * Conditional delete the persistent path data, return true if it succeeds, + * otherwise (the current version is not the expected version) + */ + def conditionalDeletePath(path: String, expectedVersion: Int): Boolean = { + try { + zkClient.delete(path, expectedVersion) + true + } catch { + case e: KeeperException.BadVersionException => false + } + } + def deletePathRecursive(path: String) { try { zkClient.deleteRecursive(path) @@ -847,7 +860,7 @@ class ZkUtils(val zkClient: ZkClient, } } } - + def close() { if(zkClient != null) { zkClient.close() @@ -941,7 +954,7 @@ object ZkPath { * znode is created and the create call returns OK. 
If * the call receives a node exists event, then it checks * if the session matches. If it does, then it returns OK, - * and otherwise it fails the operation. + * and otherwise it fails the operation. */ class ZKCheckedEphemeral(path: String, @@ -952,7 +965,7 @@ class ZKCheckedEphemeral(path: String, private val getDataCallback = new GetDataCallback val latch: CountDownLatch = new CountDownLatch(1) var result: Code = Code.OK - + private class CreateCallback extends StringCallback { def processResult(rc: Int, path: String, @@ -1009,7 +1022,7 @@ class ZKCheckedEphemeral(path: String, } } } - + private def createEphemeral() { zkHandle.create(path, ZKStringSerializer.serialize(data), @@ -1018,7 +1031,7 @@ class ZKCheckedEphemeral(path: String, createCallback, null) } - + private def createRecursive(prefix: String, suffix: String) { debug("Path: %s, Prefix: %s, Suffix: %s".format(path, prefix, suffix)) if(suffix.isEmpty()) { http://git-wip-us.apache.org/repos/asf/kafka/blob/bfac36ad/core/src/test/scala/unit/kafka/security/auth/SimpleAclAuthorizerTest.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/unit/kafka/security/auth/SimpleAclAuthorizerTest.scala b/core/src/test/scala/unit/kafka/security/auth/SimpleAclAuthorizerTest.scala index efcf930..bdadb15 100644 --- a/core/src/test/scala/unit/kafka/security/auth/SimpleAclAuthorizerTest.scala +++ b/core/src/test/scala/unit/kafka/security/auth/SimpleAclAuthorizerTest.scala @@ -17,7 +17,7 @@ package kafka.security.auth import java.net.InetAddress -import java.util.UUID +import java.util.{UUID} import kafka.network.RequestChannel.Session import kafka.security.auth.Acl.WildCardHost @@ -31,6 +31,7 @@ import org.junit.{After, Before, Test} class SimpleAclAuthorizerTest extends ZooKeeperTestHarness { val simpleAclAuthorizer = new SimpleAclAuthorizer + val simpleAclAuthorizer2 = new SimpleAclAuthorizer val testPrincipal = Acl.WildCardPrincipal val testHostName = 
InetAddress.getByName("192.168.0.1") val session = new Session(testPrincipal, testHostName) @@ -48,12 +49,14 @@ class SimpleAclAuthorizerTest extends ZooKeeperTestHarness { config = KafkaConfig.fromProps(props) simpleAclAuthorizer.configure(config.originals) + simpleAclAuthorizer2.configure(config.originals) resource = new Resource(Topic, UUID.randomUUID().toString) } @After override def tearDown(): Unit = { simpleAclAuthorizer.close() + simpleAclAuthorizer2.close() } @Test @@ -254,6 +257,87 @@ class SimpleAclAuthorizerTest extends ZooKeeperTestHarness { assertEquals(acls1, authorizer.getAcls(resource1)) } + @Test + def testLocalConcurrentModificationOfResourceAcls() { + val commonResource = new Resource(Topic, "test") + + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val acl1 = new Acl(user1, Allow, WildCardHost, Read) + + val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob") + val acl2 = new Acl(user2, Deny, WildCardHost, Read) + + simpleAclAuthorizer.addAcls(Set(acl1), commonResource) + simpleAclAuthorizer.addAcls(Set(acl2), commonResource) + + TestUtils.waitAndVerifyAcls(Set(acl1, acl2), simpleAclAuthorizer, commonResource) + } + + @Test + def testDistributedConcurrentModificationOfResourceAcls() { + val commonResource = new Resource(Topic, "test") + + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val acl1 = new Acl(user1, Allow, WildCardHost, Read) + + val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob") + val acl2 = new Acl(user2, Deny, WildCardHost, Read) + + // Add on each instance + simpleAclAuthorizer.addAcls(Set(acl1), commonResource) + simpleAclAuthorizer2.addAcls(Set(acl2), commonResource) + + TestUtils.waitAndVerifyAcls(Set(acl1, acl2), simpleAclAuthorizer, commonResource) + TestUtils.waitAndVerifyAcls(Set(acl1, acl2), simpleAclAuthorizer2, commonResource) + + val user3 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "joe") + val acl3 = new Acl(user3, Deny, WildCardHost, Read) + + // 
Add on one instance and delete on another + simpleAclAuthorizer.addAcls(Set(acl3), commonResource) + val deleted = simpleAclAuthorizer2.removeAcls(Set(acl3), commonResource) + + assertTrue("The authorizer should see a value that needs to be deleted", deleted) + + TestUtils.waitAndVerifyAcls(Set(acl1, acl2), simpleAclAuthorizer, commonResource) + TestUtils.waitAndVerifyAcls(Set(acl1, acl2), simpleAclAuthorizer2, commonResource) + } + + @Test + def testHighConcurrencyModificationOfResourceAcls() { + val commonResource = new Resource(Topic, "test") + + val acls = (0 to 100).map { i => + val useri = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, i.toString) + new Acl(useri, Allow, WildCardHost, Read) + } + + // Alternate authorizer, Remove all acls that end in 0 + val concurrentFuctions = acls.map { acl => + () => { + val aclId = acl.principal.getName.toInt + if (aclId % 2 == 0) { + simpleAclAuthorizer.addAcls(Set(acl), commonResource) + } else { + simpleAclAuthorizer2.addAcls(Set(acl), commonResource) + } + if (aclId % 10 == 0) { + simpleAclAuthorizer2.removeAcls(Set(acl), commonResource) + } + } + } + + val expectedAcls = acls.filter { acl => + val aclId = acl.principal.getName.toInt + aclId % 10 != 0 + }.toSet + + TestUtils.assertConcurrent("Should support many concurrent calls", concurrentFuctions, 15000) + + TestUtils.waitAndVerifyAcls(expectedAcls, simpleAclAuthorizer, commonResource) + TestUtils.waitAndVerifyAcls(expectedAcls, simpleAclAuthorizer2, commonResource) + } + private def changeAclAndVerify(originalAcls: Set[Acl], addedAcls: Set[Acl], removedAcls: Set[Acl], resource: Resource = resource): Set[Acl] = { var acls = originalAcls http://git-wip-us.apache.org/repos/asf/kafka/blob/bfac36ad/core/src/test/scala/unit/kafka/utils/TestUtils.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala index 7b3e955..0730468 100755 
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala +++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala @@ -21,8 +21,9 @@ import java.io._ import java.nio._ import java.nio.file.Files import java.nio.channels._ -import java.util.Random -import java.util.Properties +import java.util +import java.util.concurrent.{Callable, TimeUnit, Executors} +import java.util.{Collections, Random, Properties} import java.security.cert.X509Certificate import javax.net.ssl.X509TrustManager import charset.Charset @@ -54,6 +55,7 @@ import org.apache.kafka.common.serialization.{ByteArraySerializer, Serializer} import scala.collection.Map import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ /** * Utility functions to help with testing @@ -131,6 +133,7 @@ object TestUtils extends Logging { /** * Create a kafka server instance with appropriate test settings * USING THIS IS A SIGN YOU ARE NOT WRITING A REAL UNIT TEST + * * @param config The configuration of the server */ def createServer(config: KafkaConfig, time: Time = SystemTime): KafkaServer = { @@ -141,7 +144,7 @@ object TestUtils extends Logging { /** * Create a test config for the provided parameters. - * + * * Note that if `interBrokerSecurityProtocol` is defined, the listener for the `SecurityProtocol` will be enabled. 
*/ def createBrokerConfigs(numConfigs: Int, @@ -281,6 +284,7 @@ object TestUtils extends Logging { /** * Wrap the message in a message set + * * @param payload The bytes of the message */ def singleMessageSet(payload: Array[Byte], @@ -291,6 +295,7 @@ object TestUtils extends Logging { /** * Generate an array of random bytes + * * @param numBytes The size of the array */ def randomBytes(numBytes: Int): Array[Byte] = { @@ -301,6 +306,7 @@ object TestUtils extends Logging { /** * Generate a random string of letters and digits of the given length + * * @param len The length of the string * @return The random string */ @@ -679,6 +685,7 @@ object TestUtils extends Logging { * If neither oldLeaderOpt nor newLeaderOpt is defined, wait until the leader of a partition is elected. * If oldLeaderOpt is defined, it waits until the new leader is different from the old leader. * If newLeaderOpt is defined, it waits until the new leader becomes the expected new leader. + * * @return The new leader or assertion failure if timeout is reached. */ def waitUntilLeaderIsElectedOrChanged(zkUtils: ZkUtils, topic: String, partition: Int, timeoutMs: Long = 5000L, @@ -786,6 +793,7 @@ object TestUtils extends Logging { /** * Wait until a valid leader is propagated to the metadata cache in each broker. * It assumes that the leader propagated to each broker is the same. 
+ * * @param servers The list of servers that the metadata should reach to * @param topic The topic name * @param partition The partition Id @@ -812,7 +820,7 @@ object TestUtils extends Logging { } def waitUntilLeaderIsKnown(servers: Seq[KafkaServer], topic: String, partition: Int, timeout: Long = 5000L): Unit = { - TestUtils.waitUntilTrue(() => + TestUtils.waitUntilTrue(() => servers.exists { server => server.replicaManager.getPartition(topic, partition).exists(_.leaderReplicaIfLocal().isDefined) }, @@ -968,12 +976,11 @@ object TestUtils extends Logging { /** * Consume all messages (or a specific number of messages) + * * @param topicMessageStreams the Topic Message Streams * @param nMessagesPerThread an optional field to specify the exact number of messages to be returned. * ConsumerTimeoutException will be thrown if there are no messages to be consumed. * If not specified, then all available messages will be consumed, and no exception is thrown. - * - * * @return the list of messages consumed. */ def getMessages(topicMessageStreams: Map[String, List[KafkaStream[String, String]]], @@ -1033,6 +1040,7 @@ object TestUtils extends Logging { /** * Translate the given buffer into a string + * * @param buffer The buffer to translate * @param encoding The encoding to use in translating bytes to characters */ @@ -1075,6 +1083,46 @@ object TestUtils extends Logging { s"expected acls $expected but got ${authorizer.getAcls(resource)}", waitTime = 10000) } + /** + * To use this you pass in a sequence of functions that are your arrange/act/assert test on the SUT. + * They all run at the same time in the assertConcurrent method; the chances of triggering a multithreading code error, + * and thereby failing some assertion are greatly increased. + */ + def assertConcurrent(message: String, functions: Seq[() => Any], timeoutMs: Int) { + + def failWithTimeout() { + fail(s"$message. 
Timed out, the concurrent functions took more than $timeoutMs milliseconds") + } + + val numThreads = functions.size + val threadPool = Executors.newFixedThreadPool(numThreads) + val exceptions = ArrayBuffer[Throwable]() + try { + val runnables = functions.map { function => + new Callable[Unit] { + override def call(): Unit = function() + } + }.asJava + val futures = threadPool.invokeAll(runnables, timeoutMs, TimeUnit.MILLISECONDS).asScala + futures.foreach { future => + if (future.isCancelled) + failWithTimeout() + else + try future.get() + catch { case e: Exception => + exceptions += e + } + } + } catch { + case ie: InterruptedException => failWithTimeout() + case e => exceptions += e + } finally { + threadPool.shutdownNow() + } + assertTrue(s"$message failed with exception(s) $exceptions", exceptions.isEmpty) + + } + } class IntEncoder(props: VerifiableProperties = null) extends Encoder[Int] {
