diff --git a/core/src/main/scala/kafka/network/BoundedByteBufferReceive.scala b/core/src/main/scala/kafka/network/BoundedByteBufferReceive.scala index cd586940fbc..d36944aa5dd 100644 --- a/core/src/main/scala/kafka/network/BoundedByteBufferReceive.scala +++ b/core/src/main/scala/kafka/network/BoundedByteBufferReceive.scala @@ -56,8 +56,10 @@ private[kafka] class BoundedByteBufferReceive(val maxSize: Int) extends Receive if(contentBuffer == null && !sizeBuffer.hasRemaining) { sizeBuffer.rewind() val size = sizeBuffer.getInt() - if(size <= 0 || size > maxSize) - throw new InvalidRequestException(size + " is not a valid message size") + if(size <= 0) + throw new InvalidRequestException("%d is not a valid request size.".format(size)) + if(size > maxSize) + throw new InvalidRequestException("Request of length %d is not valid, it is larger than the maximum size of %d bytes.".format(size, maxSize)) contentBuffer = byteBufferAllocate(size) } // if we have a buffer read some stuff into it diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index 30c2f5df40e..86836eb5274 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -34,10 +34,11 @@ import kafka.api.RequestKeys * 1 Acceptor thread that handles new connections * N Processor threads that each have their own selectors and handle all requests from their connections synchronously */ -private[kafka] class SocketServer(val port: Int, +class SocketServer(val port: Int, val numProcessorThreads: Int, monitoringPeriodSecs: Int, - private val handlerFactory: Handler.HandlerMapping) { + private val handlerFactory: Handler.HandlerMapping, + val maxRequestSize: Int = Int.MaxValue) { private val logger = Logger.getLogger(classOf[SocketServer]) private val time = SystemTime @@ -50,7 +51,7 @@ private[kafka] class SocketServer(val port: Int, */ def startup() { for(i <- 0 until numProcessorThreads) { - processors(i) = new Processor(handlerFactory, time, stats) + processors(i) = new Processor(handlerFactory, time, stats, maxRequestSize) Utils.newThread("kafka-processor-" + i, processors(i), false).start() } Utils.newThread("kafka-acceptor", acceptor, false).start() @@ -179,8 +180,9 @@ private[kafka] class Acceptor(val port: Int, private val processors: Array[Proce * each of which has its own selectors */ private[kafka] class Processor(val handlerMapping: Handler.HandlerMapping, - val time: Time, - val stats: SocketServerStats) extends AbstractServerThread { + val time: Time, + val stats: SocketServerStats, + val maxRequestSize: Int) extends AbstractServerThread { private val newConnections = new ConcurrentLinkedQueue[SocketChannel](); private val requestLogger = Logger.getLogger("kafka.request.logger") @@ -211,11 +213,14 @@ private[kafka] class Processor(val handlerMapping: Handler.HandlerMapping, throw new IllegalStateException("Unrecognized key state for processor thread.") } catch { case e: EOFException => { - logger.info("Closing socket for " + channelFor(key).socket.getInetAddress + ".") + logger.info("Closing socket connection to %s.".format(channelFor(key).socket.getInetAddress)) close(key) + } + case e: InvalidRequestException => { + logger.info("Closing socket connection to %s due to invalid request: %s".format(channelFor(key).socket.getInetAddress, e.getMessage)) + close(key) } case e: Throwable => { - logger.info("Closing socket for " + channelFor(key).socket.getInetAddress + " because of error") - logger.error(e, e) + logger.error("Closing socket for " + channelFor(key).socket.getInetAddress + " because of error", e) close(key) } } @@ -293,7 +298,7 @@ private[kafka] class Processor(val handlerMapping: Handler.HandlerMapping, val socketChannel = channelFor(key) var request = key.attachment.asInstanceOf[Receive] if(key.attachment == null) { - request = new BoundedByteBufferReceive() + request = new BoundedByteBufferReceive(maxRequestSize) key.attach(request) } val read = request.readFrom(socketChannel) diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala index a0f1140e26f..c7c74ecdf1d 100644 --- a/core/src/main/scala/kafka/server/KafkaServer.scala +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -68,7 +68,8 @@ class KafkaServer(val config: KafkaConfig) { socketServer = new SocketServer(config.port, config.numThreads, config.monitoringPeriodSecs, - handlers.handlerFor) + handlers.handlerFor, + config.maxSocketRequestSize) Utils.swallow(logger.warn, Utils.registerMBean(socketServer.stats, statsMBeanName)) socketServer.startup /** diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala new file mode 100644 index 00000000000..71b9a689087 --- /dev/null +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network; + +import java.net._ +import java.io._ +import java.nio._ +import java.nio.channels._ +import org.junit._ +import junit.framework.Assert._ +import org.scalatest.junit.JUnitSuite +import kafka.utils.TestUtils +import kafka.network._ +import java.util.Random +import org.apache.log4j._ + +class SocketServerTest extends JUnitSuite { + + Logger.getLogger("kafka").setLevel(Level.INFO) + + def echo(receive: Receive): Option[Send] = { + val id = receive.buffer.getShort + Some(new BoundedByteBufferSend(receive.buffer.slice)) + } + + val server = new SocketServer(port = TestUtils.choosePort, + numProcessorThreads = 1, + monitoringPeriodSecs = 30, + handlerFactory = (requestId: Short, receive: Receive) => echo, + maxRequestSize = 50) + server.startup() + + def sendRequest(id: Short, request: Array[Byte]): Array[Byte] = { + val socket = new Socket("localhost", server.port) + val outgoing = new DataOutputStream(socket.getOutputStream) + outgoing.writeInt(request.length + 2) + outgoing.writeShort(id) + outgoing.write(request) + outgoing.flush() + val incoming = new DataInputStream(socket.getInputStream) + val len = incoming.readInt() + val response = new Array[Byte](len) + incoming.readFully(response) + socket.close() + response + } + + @After + def cleanup() { + server.shutdown() + } + + @Test + def simpleRequest() { + val response = new String(sendRequest(0, "hello".getBytes)) + + } + + @Test(expected=classOf[EOFException]) + def tooBigRequestIsRejected() { + val tooManyBytes = new Array[Byte](server.maxRequestSize + 1) + new Random().nextBytes(tooManyBytes) + sendRequest(0, tooManyBytes) + } + +} \ No newline at end of file diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala index db0043992ae..25f6b49e8f8 100644 --- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala +++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala @@ -88,6 +88,7 @@ object TestUtils { /** * Create a kafka server instance with appropriate test settings + * USING THIS IS A SIGN YOU ARE NOT WRITING A REAL UNIT TEST * @param config The configuration of the server */ def createServer(config: KafkaConfig): KafkaServer = {