kafka-527; Compression support does numerous byte copies; patched by Yasuhiro Matsuda; reviewed by Guozhang Wang and Jun Rao

Yasuhiro Matsuda 2015-03-25 13:08:38 -07:00 committed by Jun Rao
parent eb2100876b
commit a74688de46
3 changed files with 370 additions and 31 deletions
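At a high level, the patch replaces the ByteArrayOutputStream-based paths in ByteBufferMessageSet.create and decompress, which copied message bytes several times (into the stream's growing array, out via toByteArray, into a wrapper Message, and finally into the destination ByteBuffer), with a MessageWriter/BufferingOutputStream that accumulates data in chained segments and is drained into the destination buffer in a single pass. Below is a minimal, illustrative sketch of the new write path, modeled on the mkMessageWithWriter helper in the new test further down; the object and method names are made up for illustration, and it assumes compilation inside the kafka.message package so the classes added by this patch are visible.

package kafka.message

import java.nio.ByteBuffer

// Illustrative sketch only; not part of the patch.
object MessageWriterSketch {
  // Build a single (optionally compressed) Message without intermediate byte-array
  // copies: the payload is compressed straight into the writer's segmented buffers,
  // and the CRC and value-length prefixes are patched in afterwards via reserve().
  def buildMessage(payload: Array[Byte], codec: CompressionCodec): Message = {
    val writer = new MessageWriter(1024)
    writer.write(codec = codec) { output =>
      val out = if (codec == NoCompressionCodec) output else CompressionFactory(codec, output)
      try out.write(payload) finally out.close()
    }
    // One copy into the destination buffer, instead of toByteArray plus further copies.
    val bb = ByteBuffer.allocate(writer.size)
    writer.writeTo(bb)
    bb.rewind()
    new Message(bb)
  }
}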

core/src/main/scala/kafka/message/ByteBufferMessageSet.scala

@@ -5,7 +5,7 @@
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
@@ -20,12 +20,12 @@ package kafka.message
import kafka.utils.Logging
import java.nio.ByteBuffer
import java.nio.channels._
import java.io.{InputStream, ByteArrayOutputStream, DataOutputStream}
import java.io.{InputStream, DataOutputStream}
import java.util.concurrent.atomic.AtomicLong
import kafka.utils.IteratorTemplate
object ByteBufferMessageSet {
private def create(offsetCounter: AtomicLong, compressionCodec: CompressionCodec, messages: Message*): ByteBuffer = {
if(messages.size == 0) {
MessageSet.Empty.buffer
@@ -36,52 +36,55 @@ object ByteBufferMessageSet {
buffer.rewind()
buffer
} else {
val byteArrayStream = new ByteArrayOutputStream(MessageSet.messageSetSize(messages))
val output = new DataOutputStream(CompressionFactory(compressionCodec, byteArrayStream))
var offset = -1L
try {
for(message <- messages) {
offset = offsetCounter.getAndIncrement
output.writeLong(offset)
output.writeInt(message.size)
output.write(message.buffer.array, message.buffer.arrayOffset, message.buffer.limit)
val messageWriter = new MessageWriter(math.min(math.max(MessageSet.messageSetSize(messages) / 2, 1024), 1 << 16))
messageWriter.write(codec = compressionCodec) { outputStream =>
val output = new DataOutputStream(CompressionFactory(compressionCodec, outputStream))
try {
for (message <- messages) {
offset = offsetCounter.getAndIncrement
output.writeLong(offset)
output.writeInt(message.size)
output.write(message.buffer.array, message.buffer.arrayOffset, message.buffer.limit)
}
} finally {
output.close()
}
} finally {
output.close()
}
val bytes = byteArrayStream.toByteArray
val message = new Message(bytes, compressionCodec)
val buffer = ByteBuffer.allocate(message.size + MessageSet.LogOverhead)
writeMessage(buffer, message, offset)
val buffer = ByteBuffer.allocate(messageWriter.size + MessageSet.LogOverhead)
writeMessage(buffer, messageWriter, offset)
buffer.rewind()
buffer
}
}
def decompress(message: Message): ByteBufferMessageSet = {
val outputStream: ByteArrayOutputStream = new ByteArrayOutputStream
val outputStream = new BufferingOutputStream(math.min(math.max(message.size, 1024), 1 << 16))
val inputStream: InputStream = new ByteBufferBackedInputStream(message.payload)
val intermediateBuffer = new Array[Byte](1024)
val compressed = CompressionFactory(message.compressionCodec, inputStream)
try {
Stream.continually(compressed.read(intermediateBuffer)).takeWhile(_ > 0).foreach { dataRead =>
outputStream.write(intermediateBuffer, 0, dataRead)
}
outputStream.write(compressed)
} finally {
compressed.close()
}
val outputBuffer = ByteBuffer.allocate(outputStream.size)
outputBuffer.put(outputStream.toByteArray)
outputStream.writeTo(outputBuffer)
outputBuffer.rewind
new ByteBufferMessageSet(outputBuffer)
}
private[kafka] def writeMessage(buffer: ByteBuffer, message: Message, offset: Long) {
buffer.putLong(offset)
buffer.putInt(message.size)
buffer.put(message.buffer)
message.buffer.rewind()
}
private[kafka] def writeMessage(buffer: ByteBuffer, messageWriter: MessageWriter, offset: Long) {
buffer.putLong(offset)
buffer.putInt(messageWriter.size)
messageWriter.writeTo(buffer)
}
}
/**
@@ -92,7 +95,7 @@ object ByteBufferMessageSet {
* Option 1: From a ByteBuffer which already contains the serialized message set. Consumers will use this method.
*
* Option 2: Give it a list of messages along with instructions relating to serialization format. Producers will use this method.
*
*
*/
class ByteBufferMessageSet(val buffer: ByteBuffer) extends MessageSet with Logging {
private var shallowValidByteCount = -1
@@ -100,7 +103,7 @@ class ByteBufferMessageSet(val buffer: ByteBuffer) extends MessageSet with Logging {
def this(compressionCodec: CompressionCodec, messages: Message*) {
this(ByteBufferMessageSet.create(new AtomicLong(0), compressionCodec, messages:_*))
}
def this(compressionCodec: CompressionCodec, offsetCounter: AtomicLong, messages: Message*) {
this(ByteBufferMessageSet.create(offsetCounter, compressionCodec, messages:_*))
}
@@ -123,7 +126,7 @@ class ByteBufferMessageSet(val buffer: ByteBuffer) extends MessageSet with Logging {
}
shallowValidByteCount
}
/** Write the messages in this set to the given channel */
def writeTo(channel: GatheringByteChannel, offset: Long, size: Int): Int = {
// Ignore offset and size from input. We just want to write the whole buffer to the channel.
@@ -157,11 +160,11 @@ class ByteBufferMessageSet(val buffer: ByteBuffer) extends MessageSet with Logging {
val size = topIter.getInt()
if(size < Message.MinHeaderSize)
throw new InvalidMessageException("Message found with corrupt size (" + size + ")")
// we have an incomplete message
if(topIter.remaining < size)
return allDone()
// read the current message and check correctness
val message = topIter.slice()
message.limit(size)
@@ -261,7 +264,7 @@ class ByteBufferMessageSet(val buffer: ByteBuffer) extends MessageSet with Logging {
*/
override def equals(other: Any): Boolean = {
other match {
case that: ByteBufferMessageSet =>
case that: ByteBufferMessageSet =>
buffer.equals(that.buffer)
case _ => false
}
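
The scaladoc near the top of this file describes the two ways a ByteBufferMessageSet is constructed; a small illustrative sketch of both paths, using only the public constructors shown in this diff (the payloads and the ByteBufferMessageSetUsageSketch object name are made up):

import kafka.message._

// Illustrative sketch only; not part of the patch.
object ByteBufferMessageSetUsageSketch {
  // Option 2 (producer side): build the set from messages plus a compression codec.
  val produced = new ByteBufferMessageSet(SnappyCompressionCodec,
    new Message("hello".getBytes("UTF-8")),
    new Message("there".getBytes("UTF-8")))

  // Option 1 (consumer side): wrap a ByteBuffer that already holds the serialized set.
  val consumed = new ByteBufferMessageSet(produced.buffer.duplicate())
}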

core/src/main/scala/kafka/message/MessageWriter.scala (new file)

@@ -0,0 +1,206 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.message
import java.io.{InputStream, OutputStream}
import java.nio.ByteBuffer
import kafka.utils.Crc32
class MessageWriter(segmentSize: Int) extends BufferingOutputStream(segmentSize) {
import Message._
def write(key: Array[Byte] = null, codec: CompressionCodec)(writePayload: OutputStream => Unit): Unit = {
withCrc32Prefix {
write(CurrentMagicValue)
var attributes: Byte = 0
if (codec.codec > 0)
attributes = (attributes | (CompressionCodeMask & codec.codec)).toByte
write(attributes)
// write the key
if (key == null) {
writeInt(-1)
} else {
writeInt(key.length)
write(key, 0, key.length)
}
// write the payload with length prefix
withLengthPrefix {
writePayload(this)
}
}
}
private def writeInt(value: Int): Unit = {
write(value >>> 24)
write(value >>> 16)
write(value >>> 8)
write(value)
}
private def writeInt(out: ReservedOutput, value: Int): Unit = {
out.write(value >>> 24)
out.write(value >>> 16)
out.write(value >>> 8)
out.write(value)
}
private def withCrc32Prefix(writeData: => Unit): Unit = {
// get a writer for CRC value
val crcWriter = reserve(CrcLength)
// save current position
var seg = currentSegment
val offset = currentSegment.written
// write data
writeData
// compute CRC32
val crc = new Crc32()
if (offset < seg.written) crc.update(seg.bytes, offset, seg.written - offset)
seg = seg.next
while (seg != null) {
if (seg.written > 0) crc.update(seg.bytes, 0, seg.written)
seg = seg.next
}
// write CRC32
writeInt(crcWriter, crc.getValue().toInt)
}
private def withLengthPrefix(writeData: => Unit): Unit = {
// get a writer for length value
val lengthWriter = reserve(ValueSizeLength)
// save current size
val oldSize = size
// write data
writeData
// write length value
writeInt(lengthWriter, size - oldSize)
}
}
/*
* OutputStream that buffers incoming data in segmented byte arrays
* This does not copy data when expanding its capacity
* It has the ability to
* - write data directly to ByteBuffer
* - copy data from an input stream to internal segmented arrays directly
* - hold a place holder for an unknown value that can be filled in later
*/
class BufferingOutputStream(segmentSize: Int) extends OutputStream {
protected final class Segment(size: Int) {
val bytes = new Array[Byte](size)
var written = 0
var next: Segment = null
def freeSpace: Int = bytes.length - written
}
protected class ReservedOutput(seg: Segment, offset: Int, length: Int) extends OutputStream {
private[this] var cur = seg
private[this] var off = offset
private[this] var len = length
override def write(value: Int) = {
if (len <= 0) throw new IndexOutOfBoundsException()
if (cur.bytes.length <= off) {
cur = cur.next
off = 0
}
cur.bytes(off) = value.toByte
off += 1
len -= 1
}
}
protected var currentSegment = new Segment(segmentSize)
private[this] val headSegment = currentSegment
private[this] var filled = 0
def size(): Int = filled + currentSegment.written
override def write(b: Int): Unit = {
if (currentSegment.freeSpace <= 0) addSegment()
currentSegment.bytes(currentSegment.written) = b.toByte
currentSegment.written += 1
}
override def write(b: Array[Byte], off: Int, len: Int) {
if (off >= 0 && off <= b.length && len >= 0 && off + len <= b.length) {
var remaining = len
var offset = off
while (remaining > 0) {
if (currentSegment.freeSpace <= 0) addSegment()
val amount = math.min(currentSegment.freeSpace, remaining)
System.arraycopy(b, offset, currentSegment.bytes, currentSegment.written, amount)
currentSegment.written += amount
offset += amount
remaining -= amount
}
} else {
throw new IndexOutOfBoundsException()
}
}
def write(in: InputStream): Unit = {
var amount = 0
while (amount >= 0) {
currentSegment.written += amount
if (currentSegment.freeSpace <= 0) addSegment()
amount = in.read(currentSegment.bytes, currentSegment.written, currentSegment.freeSpace)
}
}
private def addSegment() = {
filled += currentSegment.written
val newSeg = new Segment(segmentSize)
currentSegment.next = newSeg
currentSegment = newSeg
}
private def skip(len: Int): Unit = {
if (len >= 0) {
var remaining = len
while (remaining > 0) {
if (currentSegment.freeSpace <= 0) addSegment()
val amount = math.min(currentSegment.freeSpace, remaining)
currentSegment.written += amount
remaining -= amount
}
} else {
throw new IndexOutOfBoundsException()
}
}
def reserve(len: Int): ReservedOutput = {
val out = new ReservedOutput(currentSegment, currentSegment.written, len)
skip(len)
out
}
def writeTo(buffer: ByteBuffer): Unit = {
var seg = headSegment
while (seg != null) {
buffer.put(seg.bytes, 0, seg.written)
seg = seg.next
}
}
}
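
The reserve() place-holder is what lets withCrc32Prefix and withLengthPrefix above write the data first and fill in the prefix afterwards without shifting any bytes. A minimal illustrative sketch of that pattern (the names and sizes are made up, mirroring testBufferingOutputStream in the new test below):

import java.io.OutputStream
import java.nio.ByteBuffer
import kafka.message.BufferingOutputStream

// Illustrative sketch only; not part of the patch.
object BufferingOutputStreamSketch {
  def lengthPrefixed(body: Array[Byte]): ByteBuffer = {
    val out = new BufferingOutputStream(16)      // small segments to exercise chaining
    val slot: OutputStream = out.reserve(4)      // place-holder for the 4-byte length prefix
    out.write(body)                              // the body goes straight into the segments
    // Fill the reserved slot now that the length is known (big-endian, like writeInt above).
    slot.write(body.length >>> 24)
    slot.write(body.length >>> 16)
    slot.write(body.length >>> 8)
    slot.write(body.length)
    // Drain all segments into the destination buffer in one pass.
    val bb = ByteBuffer.allocate(out.size)
    out.writeTo(bb)
    bb.rewind()
    bb
  }
}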

core/src/test/scala/unit/kafka/message/MessageWriterTest.scala (new file)

@@ -0,0 +1,130 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kafka.message
import java.io.{InputStream, ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer
import java.util.Random
import junit.framework.Assert._
import org.junit.Test
import org.scalatest.junit.JUnitSuite
class MessageWriterTest extends JUnitSuite {
private val rnd = new Random()
private def mkRandomArray(size: Int): Array[Byte] = {
(0 until size).map(_ => rnd.nextInt(10).toByte).toArray
}
private def mkMessageWithWriter(key: Array[Byte] = null, bytes: Array[Byte], codec: CompressionCodec): Message = {
val writer = new MessageWriter(100)
writer.write(key = key, codec = codec) { output =>
val out = if (codec == NoCompressionCodec) output else CompressionFactory(codec, output)
try {
val p = rnd.nextInt(bytes.length)
out.write(bytes, 0, p)
out.write(bytes, p, bytes.length - p)
} finally {
out.close()
}
}
val bb = ByteBuffer.allocate(writer.size)
writer.writeTo(bb)
bb.rewind()
new Message(bb)
}
private def compress(bytes: Array[Byte], codec: CompressionCodec): Array[Byte] = {
val baos = new ByteArrayOutputStream()
val out = CompressionFactory(codec, baos)
out.write(bytes)
out.close()
baos.toByteArray
}
private def decompress(compressed: Array[Byte], codec: CompressionCodec): Array[Byte] = {
toArray(CompressionFactory(codec, new ByteArrayInputStream(compressed)))
}
private def toArray(in: InputStream): Array[Byte] = {
val out = new ByteArrayOutputStream()
val buf = new Array[Byte](100)
var amount = in.read(buf)
while (amount >= 0) {
out.write(buf, 0, amount)
amount = in.read(buf)
}
out.toByteArray
}
private def toArray(bb: ByteBuffer): Array[Byte] = {
val arr = new Array[Byte](bb.limit())
bb.get(arr)
bb.rewind()
arr
}
@Test
def testBufferingOutputStream(): Unit = {
val out = new BufferingOutputStream(50)
out.write(0)
out.write(1)
out.write(2)
val r = out.reserve(100)
out.write((103 until 200).map(_.toByte).toArray)
r.write((3 until 103).map(_.toByte).toArray)
val buf = ByteBuffer.allocate(out.size)
out.writeTo(buf)
buf.rewind()
assertEquals((0 until 200).map(_.toByte), buf.array.toSeq)
}
@Test
def testWithNoCompressionAttribute(): Unit = {
val bytes = mkRandomArray(4096)
val actual = mkMessageWithWriter(bytes = bytes, codec = NoCompressionCodec)
val expected = new Message(bytes, NoCompressionCodec)
assertEquals(expected.buffer, actual.buffer)
}
@Test
def testWithCompressionAttribute(): Unit = {
val bytes = mkRandomArray(4096)
val actual = mkMessageWithWriter(bytes = bytes, codec = SnappyCompressionCodec)
val expected = new Message(compress(bytes, SnappyCompressionCodec), SnappyCompressionCodec)
assertEquals(
decompress(toArray(expected.payload), SnappyCompressionCodec).toSeq,
decompress(toArray(actual.payload), SnappyCompressionCodec).toSeq
)
}
@Test
def testWithKey(): Unit = {
val key = mkRandomArray(123)
val bytes = mkRandomArray(4096)
val actual = mkMessageWithWriter(bytes = bytes, key = key, codec = NoCompressionCodec)
val expected = new Message(bytes = bytes, key = key, codec = NoCompressionCodec)
assertEquals(expected.buffer, actual.buffer)
}
}