From 73c646c442fd17e3f9919eb2fd50fdac75e32917 Mon Sep 17 00:00:00 2001 From: parafiend Date: Fri, 9 Feb 2018 04:59:18 -0800 Subject: [PATCH] KAFKA-6529: Stop file descriptor leak when client disconnects with staged receives (#4517) If an exception is encountered while sending data to a client connection, that connection is disconnected. If there are staged receives for that connection, they are tracked to process those records. However, if the exception was encountered while processing a `RequestChannel.Request`, the `KafkaChannel` for that connection is muted and won't be processed. Disable processing of outstanding staged receives if a send fails. This stops the leak of the memory for pending requests and the file descriptor of the TCP socket. Test that a channel is closed when an exception is raised while writing to a socket that has been closed by the client. Since sending a response requires acks != 0, allow specifying the required acks for test requests in SocketServerTest.scala. Author: Graham Campbell Reviewers: Jason Gustafson, Rajini Sivaram, Ismael Juma, Ted Yu --- checkstyle/suppressions.xml | 2 +- .../apache/kafka/common/network/Selector.java | 44 +++++++++--- .../scala/kafka/network/SocketServer.scala | 8 +++ .../unit/kafka/network/SocketServerTest.scala | 68 +++++++++++++++---- 4 files changed, 99 insertions(+), 23 deletions(-) diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml index 98256710004..c35c77ce03c 100644 --- a/checkstyle/suppressions.xml +++ b/checkstyle/suppressions.xml @@ -54,7 +54,7 @@ files="AbstractRequest.java"/> + files="(BufferPool|MetricName|Node|ConfigDef|SslTransportLayer|MetadataResponse|KerberosLogin|SslTransportLayer|Selector|Sender|Serdes|PluginUtils).java"/> deque = stagedReceives.get(channel); return deque == null ? 
0 : deque.size(); } diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index 6db70cfb5d3..af942314ec6 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -605,6 +605,14 @@ private[kafka] class Processor(val id: Int, private[network] def channel(connectionId: String): Option[KafkaChannel] = Option(selector.channel(connectionId)) + /* For test usage */ + private[network] def openOrClosingChannel(connectionId: String): Option[KafkaChannel] = + channel(connectionId).orElse(Option(selector.closingChannel(connectionId))) + + // Visible for testing + private[network] def numStagedReceives(connectionId: String): Int = + openOrClosingChannel(connectionId).map(c => selector.numStagedReceives(c)).getOrElse(0) + /** * Wakeup the thread for selection. */ diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index acf96e86dfe..a3897c0992b 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -65,7 +65,7 @@ class SocketServerTest extends JUnitSuite { server.startup() val sockets = new ArrayBuffer[Socket] - def sendRequest(socket: Socket, request: Array[Byte], id: Option[Short] = None) { + def sendRequest(socket: Socket, request: Array[Byte], id: Option[Short] = None, flush: Boolean = true) { val outgoing = new DataOutputStream(socket.getOutputStream) id match { case Some(id) => @@ -75,7 +75,8 @@ class SocketServerTest extends JUnitSuite { outgoing.writeInt(request.length) } outgoing.write(request) - outgoing.flush() + if (flush) + outgoing.flush() } def receiveResponse(socket: Socket): Array[Byte] = { @@ -86,10 +87,15 @@ class SocketServerTest extends JUnitSuite { response } + private def receiveRequest(channel: RequestChannel, timeout: Long = 2000L): 
RequestChannel.Request = { + val request = channel.receiveRequest(timeout) + assertNotNull("receiveRequest timed out", request) + request + } + /* A simple request handler that just echos back the response */ def processRequest(channel: RequestChannel) { - val request = channel.receiveRequest(2000) - assertNotNull("receiveRequest timed out", request) + val request = receiveRequest(channel) processRequest(channel, request) } @@ -115,12 +121,11 @@ class SocketServerTest extends JUnitSuite { sockets.clear() } - private def producerRequestBytes: Array[Byte] = { + private def producerRequestBytes(ack: Short = 0): Array[Byte] = { val apiKey: Short = 0 val correlationId = -1 val clientId = "" val ackTimeoutMs = 10000 - val ack = 0: Short val emptyRequest = new ProduceRequest.Builder(RecordBatch.CURRENT_MAGIC_VALUE, ack, ackTimeoutMs, new HashMap[TopicPartition, MemoryRecords]()).build() @@ -133,11 +138,30 @@ class SocketServerTest extends JUnitSuite { serializedBytes } + private def sendRequestsUntilStagedReceive(server: SocketServer, socket: Socket, requestBytes: Array[Byte]): RequestChannel.Request = { + def sendTwoRequestsReceiveOne(): RequestChannel.Request = { + sendRequest(socket, requestBytes, flush = false) + sendRequest(socket, requestBytes, flush = true) + receiveRequest(server.requestChannel) + } + val (request, hasStagedReceives) = TestUtils.computeUntilTrue(sendTwoRequestsReceiveOne()) { req => + val connectionId = req.connectionId + val hasStagedReceives = server.processor(0).numStagedReceives(connectionId) > 0 + if (!hasStagedReceives) { + processRequest(server.requestChannel, req) + processRequest(server.requestChannel) + } + hasStagedReceives + } + assertTrue(s"Receives not staged for ${org.apache.kafka.test.TestUtils.DEFAULT_MAX_WAIT_MS} ms", hasStagedReceives) + request + } + @Test def simpleRequest() { val plainSocket = connect(protocol = SecurityProtocol.PLAINTEXT) val traceSocket = connect(protocol = SecurityProtocol.TRACE) - val serializedBytes = 
producerRequestBytes + val serializedBytes = producerRequestBytes() // Test PLAINTEXT socket sendRequest(plainSocket, serializedBytes) @@ -171,7 +195,7 @@ class SocketServerTest extends JUnitSuite { @Test def testGracefulClose() { val plainSocket = connect(protocol = SecurityProtocol.PLAINTEXT) - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() for (_ <- 0 until 10) sendRequest(plainSocket, serializedBytes) @@ -236,7 +260,7 @@ class SocketServerTest extends JUnitSuite { TestUtils.waitUntilTrue(() => server.connectionCount(address) < conns.length, "Failed to decrement connection count after close") val conn2 = connect() - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() sendRequest(conn2, serializedBytes) val request = server.requestChannel.receiveRequest(2000) assertNotNull(request) @@ -255,7 +279,7 @@ class SocketServerTest extends JUnitSuite { val conns = (0 until overrideNum).map(_ => connect(overrideServer)) // it should succeed - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() sendRequest(conns.last, serializedBytes) val request = overrideServer.requestChannel.receiveRequest(2000) assertNotNull(request) @@ -341,7 +365,7 @@ class SocketServerTest extends JUnitSuite { try { overrideServer.startup() conn = connect(overrideServer) - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() sendRequest(conn, serializedBytes) val channel = overrideServer.requestChannel @@ -367,6 +391,26 @@ class SocketServerTest extends JUnitSuite { } } + @Test + def testClientDisconnectionWithStagedReceivesFullyProcessed() { + val socket = connect(server) + + // Setup channel to client with staged receives so when client disconnects + // it will be stored in Selector.closingChannels + val serializedBytes = producerRequestBytes(1) + val request = sendRequestsUntilStagedReceive(server, socket, serializedBytes) + val connectionId = 
request.connectionId + + // Set SoLinger to 0 to force a hard disconnect via TCP RST + socket.setSoLinger(true, 0) + socket.close() + + // Complete request with socket exception so that the channel is removed from Selector.closingChannels + processRequest(server.requestChannel, request) + TestUtils.waitUntilTrue(() => server.processor(0).openOrClosingChannel(connectionId).isEmpty, + "Channel not closed after failed send") + } + /* * Test that we update request metrics if the channel has been removed from the selector when the broker calls * `selector.send` (selector closes old connections, for example). @@ -381,7 +425,7 @@ class SocketServerTest extends JUnitSuite { try { overrideServer.startup() conn = connect(overrideServer) - val serializedBytes = producerRequestBytes + val serializedBytes = producerRequestBytes() sendRequest(conn, serializedBytes) val channel = overrideServer.requestChannel val request = channel.receiveRequest(2000)