diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index e91c240415c..90769a376fb 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -24,7 +24,6 @@ import java.nio.channels.{Selector => NSelector, _} import java.util import java.util.concurrent._ import java.util.concurrent.atomic._ - import kafka.cluster.{BrokerEndPoint, EndPoint} import kafka.metrics.KafkaMetricsGroup import kafka.network.ConnectionQuotas._ @@ -48,6 +47,7 @@ import org.apache.kafka.common.requests.{ApiVersionsRequest, RequestContext, Req import org.apache.kafka.common.security.auth.SecurityProtocol import org.apache.kafka.common.utils.{KafkaThread, LogContext, Time, Utils} import org.apache.kafka.common.{Endpoint, KafkaException, MetricName, Reconfigurable} +import org.apache.kafka.server.util.FutureUtils import org.slf4j.event.Level import scala.collection._ @@ -189,10 +189,14 @@ class SocketServer(val config: KafkaConfig, * processor corresponding to the [[EndPoint]]. Any endpoint * that does not appear in this map will be started once all * authorizerFutures are complete. + * + * @return A future which is completed when all of the acceptor threads have + * successfully started. If any of them do not start, the future will + * be completed with an exception. */ def enableRequestProcessing( authorizerFutures: Map[Endpoint, CompletableFuture[Void]] - ): Unit = this.synchronized { + ): CompletableFuture[Void] = this.synchronized { if (stopped) { throw new RuntimeException("Can't enable request processing: SocketServer is stopped.") } @@ -200,19 +204,36 @@ class SocketServer(val config: KafkaConfig, def chainAcceptorFuture(acceptor: Acceptor): Unit = { // Because of ephemeral ports, we need to match acceptors to futures by looking at // the listener name, rather than the endpoint object. - authorizerFutures.find { + val authorizerFuture = authorizerFutures.find { case (endpoint, _) => acceptor.endPoint.listenerName.value().equals(endpoint.listenerName().get()) } match { - case None => chainFuture(allAuthorizerFuturesComplete, acceptor.startFuture) - case Some((_, future)) => chainFuture(future, acceptor.startFuture) + case None => allAuthorizerFuturesComplete + case Some((_, future)) => future } + authorizerFuture.whenComplete((_, e) => { + if (e != null) { + // If the authorizer failed to start, fail the acceptor's startedFuture. + acceptor.startedFuture.completeExceptionally(e) + } else { + // Once the authorizer has started, attempt to start the associated acceptor. The Acceptor.start() + // function will complete the acceptor started future (either successfully or not) + acceptor.start() + } + }) } info("Enabling request processing.") controlPlaneAcceptorOpt.foreach(chainAcceptorFuture) dataPlaneAcceptors.values().forEach(chainAcceptorFuture) - chainFuture(CompletableFuture.allOf(authorizerFutures.values.toArray: _*), + FutureUtils.chainFuture(CompletableFuture.allOf(authorizerFutures.values.toArray: _*), allAuthorizerFuturesComplete) + + // Construct a future that will be completed when all Acceptors have been successfully started. + // Alternately, if any of them fail to start, this future will be completed exceptionally. + val allAcceptors = dataPlaneAcceptors.values().asScala.toSeq ++ controlPlaneAcceptorOpt + val enableFuture = new CompletableFuture[Void] + FutureUtils.chainFuture(CompletableFuture.allOf(allAcceptors.map(_.startedFuture).toArray: _*), enableFuture) + enableFuture } def createDataPlaneAcceptorAndProcessors(endpoint: EndPoint): Unit = synchronized { @@ -289,13 +310,13 @@ class SocketServer(val config: KafkaConfig, try { val acceptor = dataPlaneAcceptors.get(endpoints(listenerName)) if (acceptor != null) { - acceptor.serverChannel.socket.getLocalPort + acceptor.localPort } else { - controlPlaneAcceptorOpt.map(_.serverChannel.socket().getLocalPort).getOrElse(throw new KafkaException("Could not find listenerName : " + listenerName + " in data-plane or control-plane")) + controlPlaneAcceptorOpt.map(_.localPort).getOrElse(throw new KafkaException("Could not find listenerName : " + listenerName + " in data-plane or control-plane")) } } catch { case e: Exception => - throw new KafkaException("Tried to check server's port before server was started or checked for port of non-existing protocol", e) + throw new KafkaException("Tried to check for port of non-existing protocol", e) } } @@ -312,7 +333,13 @@ class SocketServer(val config: KafkaConfig, val acceptor = dataPlaneAcceptors.get(endpoint) // There is no authorizer future for this new listener endpoint. So start the // listener once all authorizer futures are complete. - chainFuture(allAuthorizerFuturesComplete, acceptor.startFuture) + allAuthorizerFuturesComplete.whenComplete((_, e) => { + if (e != null) { + acceptor.startedFuture.completeExceptionally(e) + } else { + acceptor.start() + } + }) } } @@ -388,15 +415,6 @@ object SocketServer { CoreUtils.swallow(channel.socket().close(), logging, Level.ERROR) CoreUtils.swallow(channel.close(), logging, Level.ERROR) } - - def chainFuture(sourceFuture: CompletableFuture[Void], - destinationFuture: CompletableFuture[Void]): Unit = { - sourceFuture.whenComplete((_, t) => if (t != null) { - destinationFuture.completeExceptionally(t) - } else { - destinationFuture.complete(null) - }) - } } object DataPlaneAcceptor { @@ -573,7 +591,21 @@ private[kafka] abstract class Acceptor(val socketServer: SocketServer, private val listenBacklogSize = config.socketListenBacklogSize private val nioSelector = NSelector.open() - private[network] val serverChannel = openServerSocket(endPoint.host, endPoint.port, listenBacklogSize) + + // If the port is configured as 0, we are using a wildcard port, so we need to open the socket + // before we can find out what port we have. If it is set to a nonzero value, defer opening + // the socket until we start the Acceptor. The reason for deferring the socket opening is so + // that systems which assume that the socket being open indicates readiness are not confused. + private[network] var serverChannel: ServerSocketChannel = _ + private[network] val localPort: Int = if (endPoint.port != 0) { + endPoint.port + } else { + serverChannel = openServerSocket(endPoint.host, endPoint.port, listenBacklogSize) + val newPort = serverChannel.socket().getLocalPort() + info(s"Opened wildcard endpoint ${endPoint.host}:${newPort}") + newPort + } + private[network] val processors = new ArrayBuffer[Processor]() // Build the metric name explicitly in order to keep the existing name for compatibility private val blockedPercentMeterMetricName = explicitMetricName( @@ -585,23 +617,36 @@ private[kafka] abstract class Acceptor(val socketServer: SocketServer, private var currentProcessorIndex = 0 private[network] val throttledSockets = new mutable.PriorityQueue[DelayedCloseSocket]() private var started = false - private[network] val startFuture = new CompletableFuture[Void]() + private[network] val startedFuture = new CompletableFuture[Void]() val thread = KafkaThread.nonDaemon( s"${threadPrefix()}-kafka-socket-acceptor-${endPoint.listenerName}-${endPoint.securityProtocol}-${endPoint.port}", this) - startFuture.thenRun(() => synchronized { - if (!shouldRun.get()) { - debug(s"Ignoring start future for ${endPoint.listenerName} since the acceptor has already been shut down.") - } else { + def start(): Unit = synchronized { + try { + if (!shouldRun.get()) { + throw new ClosedChannelException() + } + if (serverChannel == null) { + serverChannel = openServerSocket(endPoint.host, endPoint.port, listenBacklogSize) + debug(s"Opened endpoint ${endPoint.host}:${endPoint.port}") + } debug(s"Starting processors for listener ${endPoint.listenerName}") - started = true processors.foreach(_.start()) debug(s"Starting acceptor thread for listener ${endPoint.listenerName}") thread.start() + startedFuture.complete(null) + started = true + } catch { + case e: ClosedChannelException => + debug(s"Refusing to start acceptor for ${endPoint.listenerName} since the acceptor has already been shut down.") + startedFuture.completeExceptionally(e) + case t: Throwable => + error(s"Unable to start acceptor for ${endPoint.listenerName}", t) + startedFuture.completeExceptionally(new RuntimeException(s"Unable to start acceptor for ${endPoint.listenerName}", t)) } - }) + } private[network] case class DelayedCloseSocket(socket: SocketChannel, endThrottleTimeMs: Long) extends Ordered[DelayedCloseSocket] { override def compare(that: DelayedCloseSocket): Int = endThrottleTimeMs compare that.endThrottleTimeMs diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala b/core/src/main/scala/kafka/server/BrokerServer.scala index e93741e10f5..37649bf93c4 100644 --- a/core/src/main/scala/kafka/server/BrokerServer.scala +++ b/core/src/main/scala/kafka/server/BrokerServer.scala @@ -476,7 +476,7 @@ class BrokerServer( // Enable inbound TCP connections. Each endpoint will be started only once its matching // authorizer future is completed. - socketServer.enableRequestProcessing(authorizerFutures) + val socketServerFuture = socketServer.enableRequestProcessing(authorizerFutures) // If we are using a ClusterMetadataAuthorizer which stores its ACLs in the metadata log, // notify it that the loading process is complete. @@ -495,6 +495,10 @@ class BrokerServer( FutureUtils.waitWithLogging(logger.underlying, "all of the authorizer futures to be completed", CompletableFuture.allOf(authorizerFutures.values.toSeq: _*), startupDeadline, time) + // Wait for all the SocketServer ports to be open, and the Acceptors to be started. + FutureUtils.waitWithLogging(logger.underlying, "all of the SocketServer Acceptors to be started", + socketServerFuture, startupDeadline, time) + maybeChangeStatus(STARTING, STARTED) } catch { case e: Throwable => diff --git a/core/src/main/scala/kafka/server/ControllerServer.scala b/core/src/main/scala/kafka/server/ControllerServer.scala index 81a831caaff..3a9321720ee 100644 --- a/core/src/main/scala/kafka/server/ControllerServer.scala +++ b/core/src/main/scala/kafka/server/ControllerServer.scala @@ -296,11 +296,16 @@ class ControllerServer( * metadata log. See @link{QuorumController#maybeCompleteAuthorizerInitialLoad} * and KIP-801 for details. */ - socketServer.enableRequestProcessing(authorizerFutures) + val socketServerFuture = socketServer.enableRequestProcessing(authorizerFutures) // Block here until all the authorizer futures are complete FutureUtils.waitWithLogging(logger.underlying, "all of the authorizer futures to be completed", CompletableFuture.allOf(authorizerFutures.values.toSeq: _*), startupDeadline, time) + + // Wait for all the SocketServer ports to be open, and the Acceptors to be started. + FutureUtils.waitWithLogging(logger.underlying, "all of the SocketServer Acceptors to be started", + socketServerFuture, startupDeadline, time) + } catch { case e: Throwable => maybeChangeStatus(STARTING, STARTED) diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index a65ccb8dbe5..459bcbce5b8 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -23,11 +23,11 @@ import java.nio.ByteBuffer import java.nio.channels.{SelectionKey, SocketChannel} import java.nio.charset.StandardCharsets import java.util -import java.util.concurrent.{CompletableFuture, ConcurrentLinkedQueue, Executors, TimeUnit} +import java.util.concurrent.{CompletableFuture, ConcurrentLinkedQueue, ExecutionException, Executors, TimeUnit} import java.util.{Properties, Random} - import com.fasterxml.jackson.databind.node.{JsonNodeFactory, ObjectNode, TextNode} import com.yammer.metrics.core.{Gauge, Meter} + import javax.net.ssl._ import kafka.cluster.EndPoint import kafka.security.CredentialProvider @@ -50,8 +50,8 @@ import org.apache.kafka.test.{TestSslUtils, TestUtils => JTestUtils} import org.apache.log4j.Level import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api._ -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicInteger import org.apache.kafka.server.metrics.KafkaYammerMetrics import scala.collection.mutable @@ -79,7 +79,7 @@ class SocketServerTest { private val apiVersionManager = new SimpleApiVersionManager(ListenerType.ZK_BROKER) val server = new SocketServer(config, metrics, Time.SYSTEM, credentialProvider, apiVersionManager) - server.enableRequestProcessing(Map.empty) + server.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) val sockets = new ArrayBuffer[Socket] private val kafkaLogger = org.apache.log4j.LogManager.getLogger("kafka") @@ -162,7 +162,18 @@ class SocketServerTest { listenerName: ListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), localAddr: InetAddress = null, port: Int = 0): Socket = { - val socket = new Socket("localhost", s.boundPort(listenerName), localAddr, port) + val boundPort = try { + s.boundPort(listenerName) + } catch { + case e: Throwable => throw new RuntimeException("Unable to find bound port for listener " + + s"${listenerName}", e) + } + val socket = try { + new Socket("localhost", boundPort, localAddr, port) + } catch { + case e: Throwable => throw new RuntimeException(s"Unable to connect to remote port ${boundPort} " + + s"with local port ${port} on listener ${listenerName}", e) + } sockets += socket socket } @@ -337,13 +348,14 @@ class SocketServerTest { val futures = Map( externalEndpoint -> externalReadyFuture, controlPlaneEndpoint -> CompletableFuture.completedFuture[Void](null)) - testableServer.enableRequestProcessing(futures) + val requestProcessingFuture = testableServer.enableRequestProcessing(futures) TestUtils.waitUntilTrue(() => controlPlaneListenerStarted(), "Control plane listener not started") assertFalse(listenerStarted(config.interBrokerListenerName)) assertFalse(listenerStarted(externalListener)) externalReadyFuture.complete(null) TestUtils.waitUntilTrue(() => listenerStarted(config.interBrokerListenerName), "Inter-broker listener not started") TestUtils.waitUntilTrue(() => listenerStarted(externalListener), "External listener not started") + requestProcessingFuture.get(1, TimeUnit.MINUTES) } finally { shutdownServerAndMetrics(testableServer) } @@ -361,6 +373,7 @@ class SocketServerTest { val config = KafkaConfig.fromProps(testProps) val connectionQueueSize = 1 val testableServer = new TestableSocketServer(config, connectionQueueSize) + testableServer.enableRequestProcessing(Map()).get(1, TimeUnit.MINUTES) val socket1 = connect(testableServer, new ListenerName("EXTERNAL"), localAddr = InetAddress.getLocalHost) sendRequest(socket1, producerRequestBytes()) @@ -466,7 +479,7 @@ class SocketServerTest { time, credentialProvider, apiVersionManager) try { - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) val serializedBytes = producerRequestBytes() // Connection with no outstanding requests @@ -534,7 +547,7 @@ class SocketServerTest { } try { - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) overrideServer.testableProcessor.setConnectionId(overrideConnectionId) val socket1 = connectAndWaitForConnectionRegister() TestUtils.waitUntilTrue(() => connectionCount == 1 && openChannel.isDefined, "Failed to create channel") @@ -803,7 +816,7 @@ class SocketServerTest { val server = new SocketServer(KafkaConfig.fromProps(newProps), new Metrics(), Time.SYSTEM, credentialProvider, apiVersionManager) try { - server.enableRequestProcessing(Map.empty) + server.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) // make the maximum allowable number of connections val conns = (0 until 5).map(_ => connect(server)) // now try one more (should fail) @@ -842,7 +855,7 @@ class SocketServerTest { val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics, Time.SYSTEM, credentialProvider, apiVersionManager) try { - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) // make the maximum allowable number of connections val conns = (0 until overrideNum).map(_ => connect(overrideServer)) @@ -882,7 +895,7 @@ class SocketServerTest { } try { - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) val conn = connect(overrideServer) conn.setSoTimeout(3000) assertEquals(-1, conn.getInputStream.read()) @@ -905,7 +918,7 @@ class SocketServerTest { // update the connection rate to 5 overrideServer.connectionQuotas.updateIpConnectionRateQuota(None, Some(connectionRate)) try { - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) // make the (maximum allowable number + 1) of connections (0 to connectionRate).map(_ => connect(overrideServer)) @@ -954,7 +967,7 @@ class SocketServerTest { val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), new Metrics(), time, credentialProvider, apiVersionManager) overrideServer.connectionQuotas.updateIpConnectionRateQuota(None, Some(connectionRate)) - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) // make the maximum allowable number of connections (0 until connectionRate).map(_ => connect(overrideServer)) // now try one more (should get throttled) @@ -977,7 +990,7 @@ class SocketServerTest { val overrideServer = new SocketServer(KafkaConfig.fromProps(sslServerProps), serverMetrics, Time.SYSTEM, credentialProvider, apiVersionManager) try { - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) val sslContext = SSLContext.getInstance(TestSslUtils.DEFAULT_TLS_PROTOCOL_FOR_TESTS) sslContext.init(null, Array(TestUtils.trustAllCerts), new java.security.SecureRandom()) val socketFactory = sslContext.getSocketFactory @@ -1036,7 +1049,7 @@ class SocketServerTest { val time = new MockTime() val overrideServer = new TestableSocketServer(KafkaConfig.fromProps(overrideProps), time = time) try { - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) val socket = connect(overrideServer, ListenerName.forSecurityProtocol(SecurityProtocol.SASL_PLAINTEXT)) val correlationId = -1 @@ -1116,7 +1129,7 @@ class SocketServerTest { val overrideServer = new TestableSocketServer(KafkaConfig.fromProps(props)) try { - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) val conn: Socket = connect(overrideServer) overrideServer.testableProcessor.closeSocketOnSendResponse(conn) val serializedBytes = producerRequestBytes() @@ -1148,7 +1161,7 @@ class SocketServerTest { val overrideServer = new TestableSocketServer(KafkaConfig.fromProps(props)) try { - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) val selector = overrideServer.testableSelector // Create a channel, send some requests and close socket. Receive one pending request after socket was closed. @@ -1176,7 +1189,7 @@ class SocketServerTest { val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, Time.SYSTEM, credentialProvider, apiVersionManager) try { - overrideServer.enableRequestProcessing(Map.empty) + overrideServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) conn = connect(overrideServer) val serializedBytes = producerRequestBytes() sendRequest(conn, serializedBytes) @@ -1557,7 +1570,7 @@ class SocketServerTest { props.put(KafkaConfig.ConnectionsMaxIdleMsProp, idleTimeMs.toString) props ++= sslServerProps val testableServer = new TestableSocketServer(time = time) - testableServer.enableRequestProcessing(Map.empty) + testableServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) assertTrue(testableServer.controlPlaneRequestChannelOpt.isEmpty) @@ -1593,7 +1606,7 @@ class SocketServerTest { val time = new MockTime() props ++= sslServerProps val testableServer = new TestableSocketServer(time = time) - testableServer.enableRequestProcessing(Map.empty) + testableServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) val proxyServer = new ProxyServer(testableServer) try { val testableSelector = testableServer.testableSelector @@ -1739,7 +1752,7 @@ class SocketServerTest { val numConnections = 5 props.put("max.connections.per.ip", numConnections.toString) val testableServer = new TestableSocketServer(KafkaConfig.fromProps(props), connectionQueueSize = 1) - testableServer.enableRequestProcessing(Map.empty) + testableServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) val testableSelector = testableServer.testableSelector val errors = new mutable.HashSet[String] @@ -1888,7 +1901,7 @@ class SocketServerTest { props ++= sslServerProps val testableServer = new TestableSocketServer(KafkaConfig.fromProps(props)) - testableServer.enableRequestProcessing(Map.empty) + testableServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) val testableSelector = testableServer.testableSelector val proxyServer = new ProxyServer(testableServer) val selectTimeoutMs = 5000 @@ -1916,6 +1929,59 @@ class SocketServerTest { } } + @Test + def testAuthorizerFailureCausesEnableRequestProcessingFailure(): Unit = { + shutdownServerAndMetrics(server) + val newServer = new SocketServer(config, metrics, Time.SYSTEM, credentialProvider, apiVersionManager) + try { + val failedFuture = new CompletableFuture[Void]() + failedFuture.completeExceptionally(new RuntimeException("authorizer startup failed")) + assertThrows(classOf[ExecutionException], () => { + newServer.enableRequestProcessing(Map(endpoint.toJava -> failedFuture)).get() + }) + } finally { + shutdownServerAndMetrics(newServer) + } + } + + @Test + def testFailedAcceptorStartupCausesEnableRequestProcessingFailure(): Unit = { + shutdownServerAndMetrics(server) + val newServer = new SocketServer(config, metrics, Time.SYSTEM, credentialProvider, apiVersionManager) + try { + newServer.dataPlaneAcceptors.values().forEach(a => a.shouldRun.set(false)) + assertThrows(classOf[ExecutionException], () => { + newServer.enableRequestProcessing(Map()).get() + }) + } finally { + shutdownServerAndMetrics(newServer) + } + } + + @Test + def testAcceptorStartOpensPortIfNeeded(): Unit = { + shutdownServerAndMetrics(server) + val newServer = new SocketServer(config, metrics, Time.SYSTEM, credentialProvider, apiVersionManager) + try { + newServer.dataPlaneAcceptors.values().forEach(a => { + a.serverChannel.close() + a.serverChannel = null + }) + val authorizerFuture = new CompletableFuture[Void]() + val enableFuture = newServer.enableRequestProcessing( + newServer.dataPlaneAcceptors.keys().asScala. + map(_.toJava).map(k => k -> authorizerFuture).toMap) + assertFalse(authorizerFuture.isDone()) + assertFalse(enableFuture.isDone()) + newServer.dataPlaneAcceptors.values().forEach(a => assertNull(a.serverChannel)) + authorizerFuture.complete(null) + enableFuture.get(1, TimeUnit.MINUTES) + newServer.dataPlaneAcceptors.values().forEach(a => assertNotNull(a.serverChannel)) + } finally { + shutdownServerAndMetrics(newServer) + } + } + private def sslServerProps: Properties = { val trustStoreFile = TestUtils.tempFile("truststore", ".jks") val sslProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, interBrokerSecurityProtocol = Some(SecurityProtocol.SSL), @@ -1930,7 +1996,7 @@ class SocketServerTest { shutdownServerAndMetrics(server) val testableServer = new TestableSocketServer(config) if (startProcessingRequests) { - testableServer.enableRequestProcessing(Map.empty) + testableServer.enableRequestProcessing(Map.empty).get(1, TimeUnit.MINUTES) } try { testWithServer(testableServer) diff --git a/core/src/test/scala/unit/kafka/server/ServerStartupTest.scala b/core/src/test/scala/unit/kafka/server/ServerStartupTest.scala index 62313498d34..210399970c7 100755 --- a/core/src/test/scala/unit/kafka/server/ServerStartupTest.scala +++ b/core/src/test/scala/unit/kafka/server/ServerStartupTest.scala @@ -18,8 +18,6 @@ package kafka.server import kafka.utils.TestUtils -import kafka.server.QuorumTestHarness -import org.apache.kafka.common.KafkaException import org.apache.kafka.metadata.BrokerState import org.apache.zookeeper.KeeperException.NodeExistsException import org.junit.jupiter.api.Assertions._ @@ -60,7 +58,7 @@ class ServerStartupTest extends QuorumTestHarness { // Create a second broker with same port val brokerId2 = 1 val props2 = TestUtils.createBrokerConfig(brokerId2, zkConnect, port = port) - assertThrows(classOf[KafkaException], () => TestUtils.createServer(KafkaConfig.fromProps(props2))) + assertThrows(classOf[IllegalArgumentException], () => TestUtils.createServer(KafkaConfig.fromProps(props2))) } @Test diff --git a/server-common/src/main/java/org/apache/kafka/server/util/FutureUtils.java b/server-common/src/main/java/org/apache/kafka/server/util/FutureUtils.java index eed0287e7b1..3383904d5ec 100644 --- a/server-common/src/main/java/org/apache/kafka/server/util/FutureUtils.java +++ b/server-common/src/main/java/org/apache/kafka/server/util/FutureUtils.java @@ -22,6 +22,7 @@ import org.slf4j.Logger; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; public class FutureUtils { @@ -66,4 +67,27 @@ public class FutureUtils { throw new RuntimeException("Received a fatal error while waiting for " + action, t); } } + + /** + * Complete a given destination future when a source future is completed. + * + * @param sourceFuture The future to trigger off of. + * @param destinationFuture The future to complete when the source future is completed. + * @param The destination future type. + */ + public static void chainFuture( + CompletableFuture sourceFuture, + CompletableFuture destinationFuture + ) { + sourceFuture.whenComplete(new BiConsumer() { + @Override + public void accept(T val, Throwable throwable) { + if (throwable != null) { + destinationFuture.completeExceptionally(throwable); + } else { + destinationFuture.complete(val); + } + } + }); + } } diff --git a/server-common/src/test/java/org/apache/kafka/server/util/FutureUtilsTest.java b/server-common/src/test/java/org/apache/kafka/server/util/FutureUtilsTest.java index 3855c1a88d2..8e3703b1104 100644 --- a/server-common/src/test/java/org/apache/kafka/server/util/FutureUtilsTest.java +++ b/server-common/src/test/java/org/apache/kafka/server/util/FutureUtilsTest.java @@ -26,12 +26,14 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -90,4 +92,31 @@ public class FutureUtilsTest { executorService.shutdown(); executorService.awaitTermination(1, TimeUnit.MINUTES); } + + @Test + public void testChainFuture() throws Throwable { + CompletableFuture sourceFuture = new CompletableFuture<>(); + CompletableFuture destinationFuture = new CompletableFuture<>(); + FutureUtils.chainFuture(sourceFuture, destinationFuture); + assertFalse(sourceFuture.isDone()); + assertFalse(destinationFuture.isDone()); + assertFalse(sourceFuture.isCancelled()); + assertFalse(destinationFuture.isCancelled()); + assertFalse(sourceFuture.isCompletedExceptionally()); + assertFalse(destinationFuture.isCompletedExceptionally()); + sourceFuture.complete(123); + assertEquals(Integer.valueOf(123), destinationFuture.get()); + } + + @Test + public void testChainFutureExceptionally() throws Throwable { + CompletableFuture sourceFuture = new CompletableFuture<>(); + CompletableFuture destinationFuture = new CompletableFuture<>(); + FutureUtils.chainFuture(sourceFuture, destinationFuture); + sourceFuture.completeExceptionally(new RuntimeException("source failed")); + Throwable cause = assertThrows(ExecutionException.class, + () -> destinationFuture.get()).getCause(); + assertEquals(RuntimeException.class, cause.getClass()); + assertEquals("source failed", cause.getMessage()); + } }