KAFKA-240 ProducerRequest wire format protocol update and related changes

git-svn-id: https://svn.apache.org/repos/asf/incubator/kafka/branches/0.8@1296577 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Joe Stein 2012-03-03 05:46:43 +00:00
parent 9fd2d4ffdb
commit f8042d5b39
20 changed files with 289 additions and 251 deletions

View File

@ -39,6 +39,15 @@ object PartitionData {
case class PartitionData(partition: Int, error: Int = ErrorMapping.NoError, initialOffset:Long = 0L, messages: MessageSet) { case class PartitionData(partition: Int, error: Int = ErrorMapping.NoError, initialOffset:Long = 0L, messages: MessageSet) {
val sizeInBytes = 4 + 4 + 8 + 4 + messages.sizeInBytes.intValue() val sizeInBytes = 4 + 4 + 8 + 4 + messages.sizeInBytes.intValue()
def this(partition: Int, messages: MessageSet) = this(partition, ErrorMapping.NoError, 0L, messages)
def getTranslatedPartition(topic: String, randomSelector: String => Int): Int = {
if (partition == ProducerRequest.RandomPartition)
return randomSelector(topic)
else
return partition
}
} }
object TopicData { object TopicData {
@ -73,6 +82,15 @@ object TopicData {
case class TopicData(topic: String, partitionData: Array[PartitionData]) { case class TopicData(topic: String, partitionData: Array[PartitionData]) {
val sizeInBytes = 2 + topic.length + partitionData.foldLeft(4)(_ + _.sizeInBytes) val sizeInBytes = 2 + topic.length + partitionData.foldLeft(4)(_ + _.sizeInBytes)
override def equals(other: Any): Boolean = {
other match {
case that: TopicData =>
( topic == that.topic &&
partitionData.toSeq == that.partitionData.toSeq )
case _ => false
}
}
} }
object FetchResponse { object FetchResponse {

View File

@ -1,57 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.api
import java.nio.ByteBuffer
import kafka.network.Request
object MultiProducerRequest {
def readFrom(buffer: ByteBuffer): MultiProducerRequest = {
val count = buffer.getShort
val produces = new Array[ProducerRequest](count)
for(i <- 0 until produces.length)
produces(i) = ProducerRequest.readFrom(buffer)
new MultiProducerRequest(produces)
}
}
class MultiProducerRequest(val produces: Array[ProducerRequest]) extends Request(RequestKeys.MultiProduce) {
def writeTo(buffer: ByteBuffer) {
if(produces.length > Short.MaxValue)
throw new IllegalArgumentException("Number of requests in MultiProducer exceeds " + Short.MaxValue + ".")
buffer.putShort(produces.length.toShort)
for(produce <- produces)
produce.writeTo(buffer)
}
def sizeInBytes: Int = {
var size = 2
for(produce <- produces)
size += produce.sizeInBytes
size
}
override def toString(): String = {
val buffer = new StringBuffer
for(produce <- produces) {
buffer.append(produce.toString)
buffer.append(",")
}
buffer.toString
}
}

View File

@ -24,60 +24,108 @@ import kafka.utils._
object ProducerRequest { object ProducerRequest {
val RandomPartition = -1 val RandomPartition = -1
val versionId: Short = 0
def readFrom(buffer: ByteBuffer): ProducerRequest = { def readFrom(buffer: ByteBuffer): ProducerRequest = {
val topic = Utils.readShortString(buffer, "UTF-8") val versionId: Short = buffer.getShort
val partition = buffer.getInt val correlationId: Int = buffer.getInt
val messageSetSize = buffer.getInt val clientId: String = Utils.readShortString(buffer, "UTF-8")
val messageSetBuffer = buffer.slice() val requiredAcks: Short = buffer.getShort
messageSetBuffer.limit(messageSetSize) val ackTimeout: Int = buffer.getInt
buffer.position(buffer.position + messageSetSize) //build the topic structure
new ProducerRequest(topic, partition, new ByteBufferMessageSet(messageSetBuffer)) val topicCount = buffer.getInt
val data = new Array[TopicData](topicCount)
for(i <- 0 until topicCount) {
val topic = Utils.readShortString(buffer, "UTF-8")
val partitionCount = buffer.getInt
//build the partition structure within this topic
val partitionData = new Array[PartitionData](partitionCount)
for (j <- 0 until partitionCount) {
val partition = buffer.getInt
val messageSetSize = buffer.getInt
val messageSetBuffer = new Array[Byte](messageSetSize)
buffer.get(messageSetBuffer,0,messageSetSize)
partitionData(j) = new PartitionData(partition,new ByteBufferMessageSet(ByteBuffer.wrap(messageSetBuffer)))
}
data(i) = new TopicData(topic,partitionData)
}
new ProducerRequest(versionId, correlationId, clientId, requiredAcks, ackTimeout, data)
} }
} }
class ProducerRequest(val topic: String, case class ProducerRequest(val versionId: Short, val correlationId: Int,
val partition: Int, val clientId: String,
val messages: ByteBufferMessageSet) extends Request(RequestKeys.Produce) { val requiredAcks: Short,
val ackTimeout: Int,
val data: Array[TopicData]) extends Request(RequestKeys.Produce) {
def this(correlationId: Int, clientId: String, requiredAcks: Short, ackTimeout: Int, data: Array[TopicData]) = this(ProducerRequest.versionId, correlationId, clientId, requiredAcks, ackTimeout, data)
def writeTo(buffer: ByteBuffer) { def writeTo(buffer: ByteBuffer) {
Utils.writeShortString(buffer, topic) buffer.putShort(versionId)
buffer.putInt(partition) buffer.putInt(correlationId)
buffer.putInt(messages.serialized.limit) Utils.writeShortString(buffer, clientId, "UTF-8")
buffer.put(messages.serialized) buffer.putShort(requiredAcks)
messages.serialized.rewind buffer.putInt(ackTimeout)
//save the topic structure
buffer.putInt(data.size) //the number of topics
data.foreach(d =>{
Utils.writeShortString(buffer, d.topic, "UTF-8") //write the topic
buffer.putInt(d.partitionData.size) //the number of partitions
d.partitionData.foreach(p => {
buffer.putInt(p.partition)
buffer.putInt(p.messages.getSerialized().limit)
buffer.put(p.messages.getSerialized())
p.messages.getSerialized().rewind
})
})
} }
def sizeInBytes(): Int = 2 + topic.length + 4 + 4 + messages.sizeInBytes.asInstanceOf[Int]
def getTranslatedPartition(randomSelector: String => Int): Int = { def sizeInBytes(): Int = {
if (partition == ProducerRequest.RandomPartition) var size = 0
return randomSelector(topic) //size, request_type_id, version_id, correlation_id, client_id, required_acks, ack_timeout, data.size
else size = 2 + 4 + 2 + clientId.length + 2 + 4 + 4;
return partition data.foreach(d =>{
size += 2 + d.topic.length + 4
d.partitionData.foreach(p => {
size += 4 + 4 + p.messages.sizeInBytes.asInstanceOf[Int]
})
})
size
} }
override def toString: String = { override def toString: String = {
val builder = new StringBuilder() val builder = new StringBuilder()
builder.append("ProducerRequest(") builder.append("ProducerRequest(")
builder.append(topic + ",") builder.append(versionId + ",")
builder.append(partition + ",") builder.append(correlationId + ",")
builder.append(messages.sizeInBytes) builder.append(clientId + ",")
builder.append(requiredAcks + ",")
builder.append(ackTimeout)
data.foreach(d =>{
builder.append(":[" + d.topic)
d.partitionData.foreach(p => {
builder.append(":[")
builder.append(p.partition + ",")
builder.append(p.messages.sizeInBytes)
builder.append("]")
})
builder.append("]")
})
builder.append(")") builder.append(")")
builder.toString builder.toString
} }
override def equals(other: Any): Boolean = { override def equals(other: Any): Boolean = {
other match { other match {
case that: ProducerRequest => case that: ProducerRequest =>
(that canEqual this) && topic == that.topic && partition == that.partition && ( correlationId == that.correlationId &&
messages.equals(that.messages) clientId == that.clientId &&
requiredAcks == that.requiredAcks &&
ackTimeout == that.ackTimeout &&
data.toSeq == that.data.toSeq)
case _ => false case _ => false
} }
} }
}
def canEqual(other: Any): Boolean = other.isInstanceOf[ProducerRequest]
override def hashCode: Int = 31 + (17 * partition) + topic.hashCode + messages.hashCode
}

View File

@ -22,7 +22,7 @@ import kafka.api.TopicData
class FetchResponse( val versionId: Short, class FetchResponse( val versionId: Short,
val correlationId: Int, val correlationId: Int,
val data: Array[TopicData] ) { private val data: Array[TopicData] ) {
private val underlying = new kafka.api.FetchResponse(versionId, correlationId, data) private val underlying = new kafka.api.FetchResponse(versionId, correlationId, data)

View File

@ -17,36 +17,29 @@
package kafka.javaapi package kafka.javaapi
import kafka.network.Request import kafka.network.Request
import kafka.api.RequestKeys import kafka.api.{RequestKeys, TopicData}
import java.nio.ByteBuffer import java.nio.ByteBuffer
class ProducerRequest(val topic: String, class ProducerRequest(val correlationId: Int,
val partition: Int, val clientId: String,
val messages: kafka.javaapi.message.ByteBufferMessageSet) extends Request(RequestKeys.Produce) { val requiredAcks: Short,
val ackTimeout: Int,
val data: Array[TopicData]) extends Request(RequestKeys.Produce) {
import Implicits._ import Implicits._
private val underlying = new kafka.api.ProducerRequest(topic, partition, messages) val underlying = new kafka.api.ProducerRequest(correlationId, clientId, requiredAcks, ackTimeout, data)
def writeTo(buffer: ByteBuffer) { underlying.writeTo(buffer) } def writeTo(buffer: ByteBuffer) { underlying.writeTo(buffer) }
def sizeInBytes(): Int = underlying.sizeInBytes def sizeInBytes(): Int = underlying.sizeInBytes
def getTranslatedPartition(randomSelector: String => Int): Int =
underlying.getTranslatedPartition(randomSelector)
override def toString: String = override def toString: String =
underlying.toString underlying.toString
override def equals(other: Any): Boolean = { override def equals(other: Any): Boolean = underlying.equals(other)
other match {
case that: ProducerRequest =>
(that canEqual this) && topic == that.topic && partition == that.partition &&
messages.equals(that.messages)
case _ => false
}
}
def canEqual(other: Any): Boolean = other.isInstanceOf[ProducerRequest] def canEqual(other: Any): Boolean = other.isInstanceOf[ProducerRequest]
override def hashCode: Int = 31 + (17 * partition) + topic.hashCode + messages.hashCode override def hashCode: Int = underlying.hashCode
} }

View File

@ -39,7 +39,7 @@ class ByteBufferMessageSet(private val buffer: ByteBuffer,
def validBytes: Long = underlying.validBytes def validBytes: Long = underlying.validBytes
def serialized():ByteBuffer = underlying.serialized def serialized():ByteBuffer = underlying.getSerialized()
def getInitialOffset = initialOffset def getInitialOffset = initialOffset

View File

@ -18,6 +18,8 @@ package kafka.javaapi.producer
import kafka.producer.SyncProducerConfig import kafka.producer.SyncProducerConfig
import kafka.javaapi.message.ByteBufferMessageSet import kafka.javaapi.message.ByteBufferMessageSet
import kafka.javaapi.ProducerRequest
import kafka.api.{PartitionData, TopicData}
class SyncProducer(syncProducer: kafka.producer.SyncProducer) { class SyncProducer(syncProducer: kafka.producer.SyncProducer) {
@ -25,21 +27,17 @@ class SyncProducer(syncProducer: kafka.producer.SyncProducer) {
val underlying = syncProducer val underlying = syncProducer
def send(topic: String, partition: Int, messages: ByteBufferMessageSet) { def send(producerRequest: kafka.javaapi.ProducerRequest) {
import kafka.javaapi.Implicits._ underlying.send(producerRequest.underlying)
underlying.send(topic, partition, messages)
} }
def send(topic: String, messages: ByteBufferMessageSet): Unit = send(topic, def send(topic: String, messages: ByteBufferMessageSet): Unit = {
kafka.api.ProducerRequest.RandomPartition, var data = new Array[TopicData](1)
messages) var partition_data = new Array[PartitionData](1)
partition_data(0) = new PartitionData(-1,messages.underlying)
def multiSend(produces: Array[kafka.javaapi.ProducerRequest]) { data(0) = new TopicData(topic,partition_data)
import kafka.javaapi.Implicits._ val producerRequest = new kafka.api.ProducerRequest(-1, "", 0, 0, data)
val produceRequests = new Array[kafka.api.ProducerRequest](produces.length) underlying.send(producerRequest)
for(i <- 0 until produces.length)
produceRequests(i) = new kafka.api.ProducerRequest(produces(i).topic, produces(i).partition, produces(i).messages)
underlying.multiSend(produceRequests)
} }
def close() { def close() {

View File

@ -53,7 +53,7 @@ class ByteBufferMessageSet(private val buffer: ByteBuffer,
def getErrorCode = errorCode def getErrorCode = errorCode
def serialized(): ByteBuffer = buffer def getSerialized(): ByteBuffer = buffer
def validBytes: Long = shallowValidBytes def validBytes: Long = shallowValidBytes

View File

@ -40,6 +40,8 @@ class FileMessageSet private[kafka](private[message] val channel: FileChannel,
private val setSize = new AtomicLong() private val setSize = new AtomicLong()
private val setHighWaterMark = new AtomicLong() private val setHighWaterMark = new AtomicLong()
def getSerialized(): ByteBuffer = throw new java.lang.UnsupportedOperationException()
if(mutable) { if(mutable) {
if(limit < Long.MaxValue || offset > 0) if(limit < Long.MaxValue || offset > 0)
throw new IllegalArgumentException("Attempt to open a mutable message set with a view or offset, which is not allowed.") throw new IllegalArgumentException("Attempt to open a mutable message set with a view or offset, which is not allowed.")

View File

@ -111,4 +111,9 @@ abstract class MessageSet extends Iterable[MessageAndOffset] {
throw new InvalidMessageException throw new InvalidMessageException
} }
/**
* Used to allow children to have serialization on implementation
*/
def getSerialized(): ByteBuffer
} }

View File

@ -51,29 +51,10 @@ class SyncProducer(val config: SyncProducerConfig) extends Logging {
if (logger.isTraceEnabled) { if (logger.isTraceEnabled) {
trace("verifying sendbuffer of size " + buffer.limit) trace("verifying sendbuffer of size " + buffer.limit)
val requestTypeId = buffer.getShort() val requestTypeId = buffer.getShort()
if (requestTypeId == RequestKeys.MultiProduce) { val request = ProducerRequest.readFrom(buffer)
try { trace(request.toString)
val request = MultiProducerRequest.readFrom(buffer)
for (produce <- request.produces) {
try {
for (messageAndOffset <- produce.messages)
if (!messageAndOffset.message.isValid)
trace("topic " + produce.topic + " is invalid")
}
catch {
case e: Throwable =>
trace("error iterating messages ", e)
}
}
}
catch {
case e: Throwable =>
trace("error verifying sendbuffer ", e)
}
}
} }
} }
/** /**
* Common functionality for the public send methods * Common functionality for the public send methods
*/ */
@ -108,21 +89,15 @@ class SyncProducer(val config: SyncProducerConfig) extends Logging {
/** /**
* Send a message * Send a message
*/ */
def send(topic: String, partition: Int, messages: ByteBufferMessageSet) { def send(producerRequest: ProducerRequest) {
verifyMessageSize(messages) producerRequest.data.foreach(d => {
val setSize = messages.sizeInBytes.asInstanceOf[Int] d.partitionData.foreach(p => {
trace("Got message set with " + setSize + " bytes to send") verifyMessageSize(new ByteBufferMessageSet(p.messages.getSerialized()))
send(new BoundedByteBufferSend(new ProducerRequest(topic, partition, messages))) val setSize = p.messages.sizeInBytes.asInstanceOf[Int]
} trace("Got message set with " + setSize + " bytes to send")
})
def send(topic: String, messages: ByteBufferMessageSet): Unit = send(topic, ProducerRequest.RandomPartition, messages) })
send(new BoundedByteBufferSend(producerRequest))
def multiSend(produces: Array[ProducerRequest]) {
for (request <- produces)
verifyMessageSize(request.messages)
val setSize = produces.foldLeft(0L)(_ + _.messages.sizeInBytes)
trace("Got multi message sets with " + setSize + " bytes to send")
send(new BoundedByteBufferSend(new MultiProducerRequest(produces)))
} }
def send(request: TopicMetadataRequest): Seq[TopicMetadata] = { def send(request: TopicMetadataRequest): Seq[TopicMetadata] = {

View File

@ -41,4 +41,23 @@ trait SyncProducerConfigShared {
val reconnectInterval = Utils.getInt(props, "reconnect.interval", 30000) val reconnectInterval = Utils.getInt(props, "reconnect.interval", 30000)
val maxMessageSize = Utils.getInt(props, "max.message.size", 1000000) val maxMessageSize = Utils.getInt(props, "max.message.size", 1000000)
/* the client application sending the producer requests */
val correlationId = Utils.getInt(props,"producer.request.correlation_id",-1)
/* the client application sending the producer requests */
val clientId = Utils.getString(props,"producer.request.client_id","")
/* the required_acks of the producer requests */
val requiredAcks = Utils.getShort(props,"producer.request.required_acks",0)
/* the ack_timeout of the producer requests */
val ackTimeout = Utils.getInt(props,"producer.request.ack_timeout",1)
} }
object SyncProducerConfig {
val DefaultCorrelationId = -1
val DefaultClientId = ""
val DefaultRequiredAcks : Short = 0
val DefaultAckTimeoutMs = 1
}

View File

@ -17,7 +17,7 @@
package kafka.producer.async package kafka.producer.async
import kafka.api.ProducerRequest import kafka.api.{ProducerRequest, TopicData, PartitionData}
import kafka.serializer.Encoder import kafka.serializer.Encoder
import kafka.producer._ import kafka.producer._
import kafka.cluster.{Partition, Broker} import kafka.cluster.{Partition, Broker}
@ -147,9 +147,22 @@ class DefaultEventHandler[K,V](config: ProducerConfig,
private def send(brokerId: Int, messagesPerTopic: Map[(String, Int), ByteBufferMessageSet]) { private def send(brokerId: Int, messagesPerTopic: Map[(String, Int), ByteBufferMessageSet]) {
if(messagesPerTopic.size > 0) { if(messagesPerTopic.size > 0) {
val requests = messagesPerTopic.map(f => new ProducerRequest(f._1._1, f._1._2, f._2)).toArray val topics = new HashMap[String, ListBuffer[PartitionData]]()
val requests = messagesPerTopic.map(f => {
val topicName = f._1._1
val partitionId = f._1._2
val messagesSet= f._2
val topic = topics.get(topicName) // checking to see if this topics exists
topic match {
case None => topics += topicName -> new ListBuffer[PartitionData]() //create a new listbuffer for this topic
case Some(x) => trace("found " + topicName)
}
topics(topicName).append(new PartitionData(partitionId, messagesSet))
})
val topicData = topics.map(kv => new TopicData(kv._1,kv._2.toArray))
val producerRequest = new ProducerRequest(config.correlationId, config.clientId, config.requiredAcks, config.ackTimeout, topicData.toArray) //new kafka.javaapi.ProducerRequest(correlation_id, client_id, required_acks, ack_timeout, topic_data.toArray)
val syncProducer = producerPool.getProducer(brokerId) val syncProducer = producerPool.getProducer(brokerId)
syncProducer.multiSend(requests) syncProducer.send(producerRequest)
trace("kafka producer sent messages for topics %s to broker %s:%d" trace("kafka producer sent messages for topics %s to broker %s:%d"
.format(messagesPerTopic, syncProducer.config.host, syncProducer.config.port)) .format(messagesPerTopic, syncProducer.config.host, syncProducer.config.port))
} }

View File

@ -41,7 +41,6 @@ class KafkaApis(val logManager: LogManager) extends Logging {
apiId match { apiId match {
case RequestKeys.Produce => handleProducerRequest(receive) case RequestKeys.Produce => handleProducerRequest(receive)
case RequestKeys.Fetch => handleFetchRequest(receive) case RequestKeys.Fetch => handleFetchRequest(receive)
case RequestKeys.MultiProduce => handleMultiProducerRequest(receive)
case RequestKeys.Offsets => handleOffsetRequest(receive) case RequestKeys.Offsets => handleOffsetRequest(receive)
case RequestKeys.TopicMetadata => handleTopicMetadataRequest(receive) case RequestKeys.TopicMetadata => handleTopicMetadataRequest(receive)
case _ => throw new IllegalStateException("No mapping found for handler id " + apiId) case _ => throw new IllegalStateException("No mapping found for handler id " + apiId)
@ -59,31 +58,38 @@ class KafkaApis(val logManager: LogManager) extends Logging {
None None
} }
def handleMultiProducerRequest(receive: Receive): Option[Send] = { private def handleProducerRequest(request: ProducerRequest, requestHandlerName: String): Option[ProducerResponse] = {
val request = MultiProducerRequest.readFrom(receive.buffer) val requestSize = request.data.size
if(requestLogger.isTraceEnabled) val errors = new Array[Int](requestSize)
requestLogger.trace("Multiproducer request " + request.toString) val offsets = new Array[Long](requestSize)
request.produces.map(handleProducerRequest(_, "MultiProducerRequest"))
None request.data.foreach(d => {
} d.partitionData.foreach(p => {
val partition = p.getTranslatedPartition(d.topic, logManager.chooseRandomPartition)
private def handleProducerRequest(request: ProducerRequest, requestHandlerName: String) = { try {
val partition = request.getTranslatedPartition(logManager.chooseRandomPartition) logManager.getOrCreateLog(d.topic, partition).append(p.messages)
try { trace(p.messages.sizeInBytes + " bytes written to logs.")
logManager.getOrCreateLog(request.topic, partition).append(request.messages) p.messages.foreach(m => trace("wrote message %s to disk".format(m.message.checksum)))
trace(request.messages.sizeInBytes + " bytes written to logs.")
} catch {
case e =>
error("Error processing " + requestHandlerName + " on " + request.topic + ":" + partition, e)
e match {
case _: IOException =>
fatal("Halting due to unrecoverable I/O error while handling producer request: " + e.getMessage, e)
System.exit(1)
case _ =>
} }
throw e catch {
} case e =>
None //TODO: handle response in ProducerResponse
error("Error processing " + requestHandlerName + " on " + d.topic + ":" + partition, e)
e match {
case _: IOException =>
fatal("Halting due to unrecoverable I/O error while handling producer request: " + e.getMessage, e)
Runtime.getRuntime.halt(1)
case _ =>
}
//throw e
}
})
//None
})
if (request.requiredAcks == 0)
None
else
None //TODO: send when KAFKA-49 can receive this Some(new ProducerResponse(request.versionId, request.correlationId, errors, offsets))
} }
def handleFetchRequest(request: Receive): Option[Send] = { def handleFetchRequest(request: Receive): Option[Send] = {

View File

@ -1,36 +0,0 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.server
import kafka.network._
import kafka.utils._
/**
* A set of message sets prefixed by size
*/
@nonthreadsafe
private[server] class MultiMessageSetSend(val sets: List[MessageSetSend]) extends MultiSend(new ByteBufferSend(6) :: sets) {
val buffer = this.sends.head.asInstanceOf[ByteBufferSend].buffer
val allMessageSetSize: Int = sets.foldLeft(0)(_ + _.sendSize)
val expectedBytesToWrite: Int = 4 + 2 + allMessageSetSize
buffer.putInt(2 + allMessageSetSize)
buffer.putShort(0)
buffer.rewind()
}

View File

@ -195,6 +195,9 @@ object Utils extends Logging {
def getInt(props: Properties, name: String, default: Int): Int = def getInt(props: Properties, name: String, default: Int): Int =
getIntInRange(props, name, default, (Int.MinValue, Int.MaxValue)) getIntInRange(props, name, default, (Int.MinValue, Int.MaxValue))
def getShort(props: Properties, name: String, default: Short): Short =
getShortInRange(props, name, default, (Short.MinValue, Short.MaxValue))
/** /**
* Read an integer from the properties instance. Throw an exception * Read an integer from the properties instance. Throw an exception
* if the value is not in the given range (inclusive) * if the value is not in the given range (inclusive)
@ -217,6 +220,18 @@ object Utils extends Logging {
v v
} }
def getShortInRange(props: Properties, name: String, default: Short, range: (Short, Short)): Short = {
val v =
if(props.containsKey(name))
props.getProperty(name).toShort
else
default
if(v < range._1 || v > range._2)
throw new IllegalArgumentException(name + " has value " + v + " which is not in the range " + range + ".")
else
v
}
def getIntInRange(buffer: ByteBuffer, name: String, range: (Int, Int)): Int = { def getIntInRange(buffer: ByteBuffer, name: String, range: (Int, Int)): Int = {
val value = buffer.getInt val value = buffer.getInt
if(value < range._1 || value > range._2) if(value < range._1 || value > range._2)
@ -777,4 +792,4 @@ class SnapshotStats(private val monitorDurationNs: Long = 600L * 1000L * 1000L *
def durationMs: Double = (end.get - start) / (1000.0 * 1000.0) def durationMs: Double = (end.get - start) / (1000.0 * 1000.0)
} }
} }

View File

@ -33,7 +33,7 @@ class ByteBufferMessageSetTest extends BaseMessageSetTestCases {
// create a ByteBufferMessageSet that doesn't contain a full message // create a ByteBufferMessageSet that doesn't contain a full message
// iterating it should get an InvalidMessageSizeException // iterating it should get an InvalidMessageSizeException
val messages = new ByteBufferMessageSet(NoCompressionCodec, new Message("01234567890123456789".getBytes())) val messages = new ByteBufferMessageSet(NoCompressionCodec, new Message("01234567890123456789".getBytes()))
val buffer = messages.serialized.slice val buffer = messages.getSerialized().slice
buffer.limit(10) buffer.limit(10)
val messageSetWithNoFullMessage = new ByteBufferMessageSet(buffer = buffer, initialOffset = 1000) val messageSetWithNoFullMessage = new ByteBufferMessageSet(buffer = buffer, initialOffset = 1000)
try { try {
@ -51,7 +51,7 @@ class ByteBufferMessageSetTest extends BaseMessageSetTestCases {
{ {
val messages = new ByteBufferMessageSet(NoCompressionCodec, new Message("hello".getBytes()), new Message("there".getBytes())) val messages = new ByteBufferMessageSet(NoCompressionCodec, new Message("hello".getBytes()), new Message("there".getBytes()))
val buffer = ByteBuffer.allocate(messages.sizeInBytes.toInt + 2) val buffer = ByteBuffer.allocate(messages.sizeInBytes.toInt + 2)
buffer.put(messages.serialized) buffer.put(messages.getSerialized())
buffer.putShort(4) buffer.putShort(4)
val messagesPlus = new ByteBufferMessageSet(buffer) val messagesPlus = new ByteBufferMessageSet(buffer)
assertEquals("Adding invalid bytes shouldn't change byte count", messages.validBytes, messagesPlus.validBytes) assertEquals("Adding invalid bytes shouldn't change byte count", messages.validBytes, messagesPlus.validBytes)
@ -93,7 +93,7 @@ class ByteBufferMessageSetTest extends BaseMessageSetTestCases {
//make sure ByteBufferMessageSet is re-iterable. //make sure ByteBufferMessageSet is re-iterable.
TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(messageSet.iterator)) TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(messageSet.iterator))
//make sure the last offset after iteration is correct //make sure the last offset after iteration is correct
assertEquals("offset of last message not expected", messageSet.last.offset, messageSet.serialized.limit) assertEquals("offset of last message not expected", messageSet.last.offset, messageSet.getSerialized().limit)
} }
// test for compressed regular messages // test for compressed regular messages
@ -103,7 +103,7 @@ class ByteBufferMessageSetTest extends BaseMessageSetTestCases {
//make sure ByteBufferMessageSet is re-iterable. //make sure ByteBufferMessageSet is re-iterable.
TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(messageSet.iterator)) TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(messageSet.iterator))
//make sure the last offset after iteration is correct //make sure the last offset after iteration is correct
assertEquals("offset of last message not expected", messageSet.last.offset, messageSet.serialized.limit) assertEquals("offset of last message not expected", messageSet.last.offset, messageSet.getSerialized().limit)
} }
// test for mixed empty and non-empty messagesets uncompressed // test for mixed empty and non-empty messagesets uncompressed
@ -111,16 +111,16 @@ class ByteBufferMessageSetTest extends BaseMessageSetTestCases {
val emptyMessageList : List[Message] = Nil val emptyMessageList : List[Message] = Nil
val emptyMessageSet = new ByteBufferMessageSet(NoCompressionCodec, emptyMessageList: _*) val emptyMessageSet = new ByteBufferMessageSet(NoCompressionCodec, emptyMessageList: _*)
val regularMessgeSet = new ByteBufferMessageSet(NoCompressionCodec, messageList: _*) val regularMessgeSet = new ByteBufferMessageSet(NoCompressionCodec, messageList: _*)
val buffer = ByteBuffer.allocate(emptyMessageSet.serialized.limit + regularMessgeSet.serialized.limit) val buffer = ByteBuffer.allocate(emptyMessageSet.getSerialized().limit + regularMessgeSet.getSerialized().limit)
buffer.put(emptyMessageSet.serialized) buffer.put(emptyMessageSet.getSerialized())
buffer.put(regularMessgeSet.serialized) buffer.put(regularMessgeSet.getSerialized())
buffer.rewind buffer.rewind
val mixedMessageSet = new ByteBufferMessageSet(buffer, 0, 0) val mixedMessageSet = new ByteBufferMessageSet(buffer, 0, 0)
TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(mixedMessageSet.iterator)) TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(mixedMessageSet.iterator))
//make sure ByteBufferMessageSet is re-iterable. //make sure ByteBufferMessageSet is re-iterable.
TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(mixedMessageSet.iterator)) TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(mixedMessageSet.iterator))
//make sure the last offset after iteration is correct //make sure the last offset after iteration is correct
assertEquals("offset of last message not expected", mixedMessageSet.last.offset, mixedMessageSet.serialized.limit) assertEquals("offset of last message not expected", mixedMessageSet.last.offset, mixedMessageSet.getSerialized().limit)
} }
// test for mixed empty and non-empty messagesets compressed // test for mixed empty and non-empty messagesets compressed
@ -128,16 +128,16 @@ class ByteBufferMessageSetTest extends BaseMessageSetTestCases {
val emptyMessageList : List[Message] = Nil val emptyMessageList : List[Message] = Nil
val emptyMessageSet = new ByteBufferMessageSet(DefaultCompressionCodec, emptyMessageList: _*) val emptyMessageSet = new ByteBufferMessageSet(DefaultCompressionCodec, emptyMessageList: _*)
val regularMessgeSet = new ByteBufferMessageSet(DefaultCompressionCodec, messageList: _*) val regularMessgeSet = new ByteBufferMessageSet(DefaultCompressionCodec, messageList: _*)
val buffer = ByteBuffer.allocate(emptyMessageSet.serialized.limit + regularMessgeSet.serialized.limit) val buffer = ByteBuffer.allocate(emptyMessageSet.getSerialized().limit + regularMessgeSet.getSerialized().limit)
buffer.put(emptyMessageSet.serialized) buffer.put(emptyMessageSet.getSerialized())
buffer.put(regularMessgeSet.serialized) buffer.put(regularMessgeSet.getSerialized())
buffer.rewind buffer.rewind
val mixedMessageSet = new ByteBufferMessageSet(buffer, 0, 0) val mixedMessageSet = new ByteBufferMessageSet(buffer, 0, 0)
TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(mixedMessageSet.iterator)) TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(mixedMessageSet.iterator))
//make sure ByteBufferMessageSet is re-iterable. //make sure ByteBufferMessageSet is re-iterable.
TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(mixedMessageSet.iterator)) TestUtils.checkEquals[Message](messageList.iterator, TestUtils.getMessageIterator(mixedMessageSet.iterator))
//make sure the last offset after iteration is correct //make sure the last offset after iteration is correct
assertEquals("offset of last message not expected", mixedMessageSet.last.offset, mixedMessageSet.serialized.limit) assertEquals("offset of last message not expected", mixedMessageSet.last.offset, mixedMessageSet.getSerialized().limit)
} }
} }

View File

@ -381,11 +381,12 @@ class AsyncProducerTest extends JUnit3Suite with ZooKeeperTestHarness {
val mockSyncProducer = EasyMock.createMock(classOf[SyncProducer]) val mockSyncProducer = EasyMock.createMock(classOf[SyncProducer])
mockSyncProducer.send(new TopicMetadataRequest(List(topic))) mockSyncProducer.send(new TopicMetadataRequest(List(topic)))
EasyMock.expectLastCall().andReturn(List(topic1Metadata)) EasyMock.expectLastCall().andReturn(List(topic1Metadata))
mockSyncProducer.multiSend(EasyMock.aryEq(Array(new ProducerRequest(topic, 0, messagesToSet(msgs.take(5)))))) mockSyncProducer.send(TestUtils.produceRequest(topic, 0,
messagesToSet(msgs.take(5))))
EasyMock.expectLastCall EasyMock.expectLastCall
mockSyncProducer.multiSend(EasyMock.aryEq(Array(new ProducerRequest(topic, 0, messagesToSet(msgs.takeRight(5)))))) mockSyncProducer.send(TestUtils.produceRequest(topic, 0,
EasyMock.expectLastCall messagesToSet(msgs.takeRight(5))))
EasyMock.replay(mockSyncProducer) EasyMock.replay(mockSyncProducer)
val producerPool = EasyMock.createMock(classOf[ProducerPool]) val producerPool = EasyMock.createMock(classOf[ProducerPool])
producerPool.getZkClient producerPool.getZkClient
@ -495,10 +496,7 @@ class AsyncProducerTest extends JUnit3Suite with ZooKeeperTestHarness {
} }
class MockProducer(override val config: SyncProducerConfig) extends SyncProducer(config) { class MockProducer(override val config: SyncProducerConfig) extends SyncProducer(config) {
override def send(topic: String, messages: ByteBufferMessageSet): Unit = { override def send(produceRequest: ProducerRequest): Unit = {
Thread.sleep(1000)
}
override def multiSend(produces: Array[ProducerRequest]) {
Thread.sleep(1000) Thread.sleep(1000)
} }
} }

View File

@ -44,7 +44,7 @@ class SyncProducerTest extends JUnit3Suite with KafkaServerTestHarness {
var failed = false var failed = false
val firstStart = SystemTime.milliseconds val firstStart = SystemTime.milliseconds
try { try {
producer.send("test", 0, new ByteBufferMessageSet(compressionCodec = NoCompressionCodec, messages = new Message(messageBytes))) producer.send(TestUtils.produceRequest("test", 0, new ByteBufferMessageSet(compressionCodec = NoCompressionCodec, messages = new Message(messageBytes))))
}catch { }catch {
case e: Exception => failed=true case e: Exception => failed=true
} }
@ -54,7 +54,7 @@ class SyncProducerTest extends JUnit3Suite with KafkaServerTestHarness {
Assert.assertTrue((firstEnd-firstStart) < 500) Assert.assertTrue((firstEnd-firstStart) < 500)
val secondStart = SystemTime.milliseconds val secondStart = SystemTime.milliseconds
try { try {
producer.send("test", 0, new ByteBufferMessageSet(compressionCodec = NoCompressionCodec, messages = new Message(messageBytes))) producer.send(TestUtils.produceRequest("test", 0, new ByteBufferMessageSet(compressionCodec = NoCompressionCodec, messages = new Message(messageBytes))))
}catch { }catch {
case e: Exception => failed = true case e: Exception => failed = true
} }
@ -63,7 +63,7 @@ class SyncProducerTest extends JUnit3Suite with KafkaServerTestHarness {
Assert.assertTrue((secondEnd-secondStart) < 500) Assert.assertTrue((secondEnd-secondStart) < 500)
try { try {
producer.multiSend(Array(new ProducerRequest("test", 0, new ByteBufferMessageSet(compressionCodec = NoCompressionCodec, messages = new Message(messageBytes))))) producer.send(TestUtils.produceRequest("test", 0, new ByteBufferMessageSet(compressionCodec = NoCompressionCodec, messages = new Message(messageBytes))))
}catch { }catch {
case e: Exception => failed=true case e: Exception => failed=true
} }
@ -83,7 +83,7 @@ class SyncProducerTest extends JUnit3Suite with KafkaServerTestHarness {
val bytes = new Array[Byte](101) val bytes = new Array[Byte](101)
var failed = false var failed = false
try { try {
producer.send("test", 0, new ByteBufferMessageSet(compressionCodec = NoCompressionCodec, messages = new Message(bytes))) producer.send(TestUtils.produceRequest("test", 0, new ByteBufferMessageSet(compressionCodec = NoCompressionCodec, messages = new Message(bytes))))
}catch { }catch {
case e: MessageSizeTooLargeException => failed = true case e: MessageSizeTooLargeException => failed = true
} }

View File

@ -33,6 +33,7 @@ import collection.mutable.ListBuffer
import kafka.consumer.{KafkaMessageStream, ConsumerConfig} import kafka.consumer.{KafkaMessageStream, ConsumerConfig}
import scala.collection.Map import scala.collection.Map
import kafka.serializer.Encoder import kafka.serializer.Encoder
import kafka.api.{ProducerRequest, TopicData, PartitionData}
/** /**
* Utility functions to help with testing * Utility functions to help with testing
@ -336,7 +337,47 @@ object TestUtils {
buffer += ("msg" + i) buffer += ("msg" + i)
buffer buffer
} }
/**
* Create a wired format request based on simple basic information
*/
def produceRequest(topic: String, message: ByteBufferMessageSet): kafka.api.ProducerRequest = {
produceRequest(SyncProducerConfig.DefaultCorrelationId,topic,ProducerRequest.RandomPartition,message)
}
def produceRequest(topic: String, partition: Int, message: ByteBufferMessageSet): kafka.api.ProducerRequest = {
produceRequest(SyncProducerConfig.DefaultCorrelationId,topic,partition,message)
}
def produceRequest(correlationId: Int, topic: String, partition: Int, message: ByteBufferMessageSet): kafka.api.ProducerRequest = {
val clientId = SyncProducerConfig.DefaultClientId
val requiredAcks: Short = SyncProducerConfig.DefaultRequiredAcks
val ackTimeout = SyncProducerConfig.DefaultAckTimeoutMs
var data = new Array[TopicData](1)
var partitionData = new Array[PartitionData](1)
partitionData(0) = new PartitionData(partition,message)
data(0) = new TopicData(topic,partitionData)
val pr = new kafka.api.ProducerRequest(correlationId, clientId, requiredAcks, ackTimeout, data)
pr
}
def produceJavaRequest(topic: String, message: kafka.javaapi.message.ByteBufferMessageSet): kafka.javaapi.ProducerRequest = {
produceJavaRequest(-1,topic,-1,message)
}
def produceJavaRequest(topic: String, partition: Int, message: kafka.javaapi.message.ByteBufferMessageSet): kafka.javaapi.ProducerRequest = {
produceJavaRequest(-1,topic,partition,message)
}
def produceJavaRequest(correlationId: Int, topic: String, partition: Int, message: kafka.javaapi.message.ByteBufferMessageSet): kafka.javaapi.ProducerRequest = {
val clientId = "test"
val requiredAcks: Short = 0
val ackTimeout = 0
var data = new Array[TopicData](1)
var partitionData = new Array[PartitionData](1)
partitionData(0) = new PartitionData(partition,message.underlying)
data(0) = new TopicData(topic,partitionData)
val pr = new kafka.javaapi.ProducerRequest(correlationId, clientId, requiredAcks, ackTimeout, data)
pr
}
} }
object TestZKUtils { object TestZKUtils {