KAFKA-99 Enforce a max request size in socket server to avoid running out of memory with very large requests.

git-svn-id: https://svn.apache.org/repos/asf/incubator/kafka/trunk@1159837 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Edward Jay Kreps 2011-08-20 04:08:06 +00:00
parent 4f56e44100
commit 4a688c5d6c
5 changed files with 102 additions and 12 deletions

View File

@ -56,8 +56,10 @@ private[kafka] class BoundedByteBufferReceive(val maxSize: Int) extends Receive
if(contentBuffer == null && !sizeBuffer.hasRemaining) {
sizeBuffer.rewind()
val size = sizeBuffer.getInt()
if(size <= 0 || size > maxSize)
throw new InvalidRequestException(size + " is not a valid message size")
if(size <= 0)
throw new InvalidRequestException("%d is not a valid request size.".format(size))
if(size > maxSize)
throw new InvalidRequestException("Request of length %d is not valid, it is larger than the maximum size of %d bytes.".format(size, maxSize))
contentBuffer = byteBufferAllocate(size)
}
// if we have a buffer read some stuff into it

View File

@ -34,10 +34,11 @@ import kafka.api.RequestKeys
* 1 Acceptor thread that handles new connections
* N Processor threads that each have their own selectors and handle all requests from their connections synchronously
*/
private[kafka] class SocketServer(val port: Int,
class SocketServer(val port: Int,
val numProcessorThreads: Int,
monitoringPeriodSecs: Int,
private val handlerFactory: Handler.HandlerMapping) {
private val handlerFactory: Handler.HandlerMapping,
val maxRequestSize: Int = Int.MaxValue) {
private val logger = Logger.getLogger(classOf[SocketServer])
private val time = SystemTime
@ -50,7 +51,7 @@ private[kafka] class SocketServer(val port: Int,
*/
def startup() {
for(i <- 0 until numProcessorThreads) {
processors(i) = new Processor(handlerFactory, time, stats)
processors(i) = new Processor(handlerFactory, time, stats, maxRequestSize)
Utils.newThread("kafka-processor-" + i, processors(i), false).start()
}
Utils.newThread("kafka-acceptor", acceptor, false).start()
@ -180,7 +181,8 @@ private[kafka] class Acceptor(val port: Int, private val processors: Array[Proce
*/
private[kafka] class Processor(val handlerMapping: Handler.HandlerMapping,
val time: Time,
val stats: SocketServerStats) extends AbstractServerThread {
val stats: SocketServerStats,
val maxRequestSize: Int) extends AbstractServerThread {
private val newConnections = new ConcurrentLinkedQueue[SocketChannel]();
private val requestLogger = Logger.getLogger("kafka.request.logger")
@ -211,11 +213,14 @@ private[kafka] class Processor(val handlerMapping: Handler.HandlerMapping,
throw new IllegalStateException("Unrecognized key state for processor thread.")
} catch {
case e: EOFException => {
logger.info("Closing socket for " + channelFor(key).socket.getInetAddress + ".")
logger.info("Closing socket connection to %s.".format(channelFor(key).socket.getInetAddress))
close(key)
}
case e: InvalidRequestException => {
logger.info("Closing socket connection to %s due to invalid request: %s".format(channelFor(key).socket.getInetAddress, e.getMessage))
close(key)
} case e: Throwable => {
logger.info("Closing socket for " + channelFor(key).socket.getInetAddress + " because of error")
logger.error(e, e)
logger.error("Closing socket for " + channelFor(key).socket.getInetAddress + " because of error", e)
close(key)
}
}
@ -293,7 +298,7 @@ private[kafka] class Processor(val handlerMapping: Handler.HandlerMapping,
val socketChannel = channelFor(key)
var request = key.attachment.asInstanceOf[Receive]
if(key.attachment == null) {
request = new BoundedByteBufferReceive()
request = new BoundedByteBufferReceive(maxRequestSize)
key.attach(request)
}
val read = request.readFrom(socketChannel)

View File

@ -68,7 +68,8 @@ class KafkaServer(val config: KafkaConfig) {
socketServer = new SocketServer(config.port,
config.numThreads,
config.monitoringPeriodSecs,
handlers.handlerFor)
handlers.handlerFor,
config.maxSocketRequestSize)
Utils.swallow(logger.warn, Utils.registerMBean(socketServer.stats, statsMBeanName))
socketServer.startup
/**

View File

@ -0,0 +1,81 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.network;
import java.net._
import java.io._
import java.nio._
import java.nio.channels._
import org.junit._
import junit.framework.Assert._
import org.scalatest.junit.JUnitSuite
import kafka.utils.TestUtils
import kafka.network._
import java.util.Random
import org.apache.log4j._
class SocketServerTest extends JUnitSuite {
Logger.getLogger("kafka").setLevel(Level.INFO)
def echo(receive: Receive): Option[Send] = {
val id = receive.buffer.getShort
Some(new BoundedByteBufferSend(receive.buffer.slice))
}
val server = new SocketServer(port = TestUtils.choosePort,
numProcessorThreads = 1,
monitoringPeriodSecs = 30,
handlerFactory = (requestId: Short, receive: Receive) => echo,
maxRequestSize = 50)
server.startup()
def sendRequest(id: Short, request: Array[Byte]): Array[Byte] = {
val socket = new Socket("localhost", server.port)
val outgoing = new DataOutputStream(socket.getOutputStream)
outgoing.writeInt(request.length + 2)
outgoing.writeShort(id)
outgoing.write(request)
outgoing.flush()
val incoming = new DataInputStream(socket.getInputStream)
val len = incoming.readInt()
val response = new Array[Byte](len)
incoming.readFully(response)
socket.close()
response
}
@After
def cleanup() {
server.shutdown()
}
@Test
def simpleRequest() {
val response = new String(sendRequest(0, "hello".getBytes))
}
@Test(expected=classOf[EOFException])
def tooBigRequestIsRejected() {
val tooManyBytes = new Array[Byte](server.maxRequestSize + 1)
new Random().nextBytes(tooManyBytes)
sendRequest(0, tooManyBytes)
}
}

View File

@ -88,6 +88,7 @@ object TestUtils {
/**
* Create a kafka server instance with appropriate test settings
* USING THIS IS A SIGN YOU ARE NOT WRITING A REAL UNIT TEST
* @param config The configuration of the server
*/
def createServer(config: KafkaConfig): KafkaServer = {