mirror of https://github.com/apache/kafka.git
KAFKA-99 Enforce a max request size in socket server to avoid running out of memory with very large requests.
git-svn-id: https://svn.apache.org/repos/asf/incubator/kafka/trunk@1159837 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
4f56e44100
commit
4a688c5d6c
|
@ -56,8 +56,10 @@ private[kafka] class BoundedByteBufferReceive(val maxSize: Int) extends Receive
|
|||
if(contentBuffer == null && !sizeBuffer.hasRemaining) {
|
||||
sizeBuffer.rewind()
|
||||
val size = sizeBuffer.getInt()
|
||||
if(size <= 0 || size > maxSize)
|
||||
throw new InvalidRequestException(size + " is not a valid message size")
|
||||
if(size <= 0)
|
||||
throw new InvalidRequestException("%d is not a valid request size.".format(size))
|
||||
if(size > maxSize)
|
||||
throw new InvalidRequestException("Request of length %d is not valid, it is larger than the maximum size of %d bytes.".format(size, maxSize))
|
||||
contentBuffer = byteBufferAllocate(size)
|
||||
}
|
||||
// if we have a buffer read some stuff into it
|
||||
|
|
|
@ -34,10 +34,11 @@ import kafka.api.RequestKeys
|
|||
* 1 Acceptor thread that handles new connections
|
||||
* N Processor threads that each have their own selectors and handle all requests from their connections synchronously
|
||||
*/
|
||||
private[kafka] class SocketServer(val port: Int,
|
||||
class SocketServer(val port: Int,
|
||||
val numProcessorThreads: Int,
|
||||
monitoringPeriodSecs: Int,
|
||||
private val handlerFactory: Handler.HandlerMapping) {
|
||||
private val handlerFactory: Handler.HandlerMapping,
|
||||
val maxRequestSize: Int = Int.MaxValue) {
|
||||
|
||||
private val logger = Logger.getLogger(classOf[SocketServer])
|
||||
private val time = SystemTime
|
||||
|
@ -50,7 +51,7 @@ private[kafka] class SocketServer(val port: Int,
|
|||
*/
|
||||
def startup() {
|
||||
for(i <- 0 until numProcessorThreads) {
|
||||
processors(i) = new Processor(handlerFactory, time, stats)
|
||||
processors(i) = new Processor(handlerFactory, time, stats, maxRequestSize)
|
||||
Utils.newThread("kafka-processor-" + i, processors(i), false).start()
|
||||
}
|
||||
Utils.newThread("kafka-acceptor", acceptor, false).start()
|
||||
|
@ -180,7 +181,8 @@ private[kafka] class Acceptor(val port: Int, private val processors: Array[Proce
|
|||
*/
|
||||
private[kafka] class Processor(val handlerMapping: Handler.HandlerMapping,
|
||||
val time: Time,
|
||||
val stats: SocketServerStats) extends AbstractServerThread {
|
||||
val stats: SocketServerStats,
|
||||
val maxRequestSize: Int) extends AbstractServerThread {
|
||||
|
||||
private val newConnections = new ConcurrentLinkedQueue[SocketChannel]();
|
||||
private val requestLogger = Logger.getLogger("kafka.request.logger")
|
||||
|
@ -211,11 +213,14 @@ private[kafka] class Processor(val handlerMapping: Handler.HandlerMapping,
|
|||
throw new IllegalStateException("Unrecognized key state for processor thread.")
|
||||
} catch {
|
||||
case e: EOFException => {
|
||||
logger.info("Closing socket for " + channelFor(key).socket.getInetAddress + ".")
|
||||
logger.info("Closing socket connection to %s.".format(channelFor(key).socket.getInetAddress))
|
||||
close(key)
|
||||
}
|
||||
case e: InvalidRequestException => {
|
||||
logger.info("Closing socket connection to %s due to invalid request: %s".format(channelFor(key).socket.getInetAddress, e.getMessage))
|
||||
close(key)
|
||||
} case e: Throwable => {
|
||||
logger.info("Closing socket for " + channelFor(key).socket.getInetAddress + " because of error")
|
||||
logger.error(e, e)
|
||||
logger.error("Closing socket for " + channelFor(key).socket.getInetAddress + " because of error", e)
|
||||
close(key)
|
||||
}
|
||||
}
|
||||
|
@ -293,7 +298,7 @@ private[kafka] class Processor(val handlerMapping: Handler.HandlerMapping,
|
|||
val socketChannel = channelFor(key)
|
||||
var request = key.attachment.asInstanceOf[Receive]
|
||||
if(key.attachment == null) {
|
||||
request = new BoundedByteBufferReceive()
|
||||
request = new BoundedByteBufferReceive(maxRequestSize)
|
||||
key.attach(request)
|
||||
}
|
||||
val read = request.readFrom(socketChannel)
|
||||
|
|
|
@ -68,7 +68,8 @@ class KafkaServer(val config: KafkaConfig) {
|
|||
socketServer = new SocketServer(config.port,
|
||||
config.numThreads,
|
||||
config.monitoringPeriodSecs,
|
||||
handlers.handlerFor)
|
||||
handlers.handlerFor,
|
||||
config.maxSocketRequestSize)
|
||||
Utils.swallow(logger.warn, Utils.registerMBean(socketServer.stats, statsMBeanName))
|
||||
socketServer.startup
|
||||
/**
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
/**
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package kafka.network;
|
||||
|
||||
import java.net._
|
||||
import java.io._
|
||||
import java.nio._
|
||||
import java.nio.channels._
|
||||
import org.junit._
|
||||
import junit.framework.Assert._
|
||||
import org.scalatest.junit.JUnitSuite
|
||||
import kafka.utils.TestUtils
|
||||
import kafka.network._
|
||||
import java.util.Random
|
||||
import org.apache.log4j._
|
||||
|
||||
class SocketServerTest extends JUnitSuite {
|
||||
|
||||
Logger.getLogger("kafka").setLevel(Level.INFO)
|
||||
|
||||
def echo(receive: Receive): Option[Send] = {
|
||||
val id = receive.buffer.getShort
|
||||
Some(new BoundedByteBufferSend(receive.buffer.slice))
|
||||
}
|
||||
|
||||
val server = new SocketServer(port = TestUtils.choosePort,
|
||||
numProcessorThreads = 1,
|
||||
monitoringPeriodSecs = 30,
|
||||
handlerFactory = (requestId: Short, receive: Receive) => echo,
|
||||
maxRequestSize = 50)
|
||||
server.startup()
|
||||
|
||||
def sendRequest(id: Short, request: Array[Byte]): Array[Byte] = {
|
||||
val socket = new Socket("localhost", server.port)
|
||||
val outgoing = new DataOutputStream(socket.getOutputStream)
|
||||
outgoing.writeInt(request.length + 2)
|
||||
outgoing.writeShort(id)
|
||||
outgoing.write(request)
|
||||
outgoing.flush()
|
||||
val incoming = new DataInputStream(socket.getInputStream)
|
||||
val len = incoming.readInt()
|
||||
val response = new Array[Byte](len)
|
||||
incoming.readFully(response)
|
||||
socket.close()
|
||||
response
|
||||
}
|
||||
|
||||
@After
|
||||
def cleanup() {
|
||||
server.shutdown()
|
||||
}
|
||||
|
||||
@Test
|
||||
def simpleRequest() {
|
||||
val response = new String(sendRequest(0, "hello".getBytes))
|
||||
|
||||
}
|
||||
|
||||
@Test(expected=classOf[EOFException])
|
||||
def tooBigRequestIsRejected() {
|
||||
val tooManyBytes = new Array[Byte](server.maxRequestSize + 1)
|
||||
new Random().nextBytes(tooManyBytes)
|
||||
sendRequest(0, tooManyBytes)
|
||||
}
|
||||
|
||||
}
|
|
@ -88,6 +88,7 @@ object TestUtils {
|
|||
|
||||
/**
|
||||
* Create a kafka server instance with appropriate test settings
|
||||
* USING THIS IS A SIGN YOU ARE NOT WRITING A REAL UNIT TEST
|
||||
* @param config The configuration of the server
|
||||
*/
|
||||
def createServer(config: KafkaConfig): KafkaServer = {
|
||||
|
|
Loading…
Reference in New Issue