diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml index ef3947ad91a..55fcd1a9e53 100644 --- a/checkstyle/suppressions.xml +++ b/checkstyle/suppressions.xml @@ -170,6 +170,10 @@ + + + diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java b/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java index bd0925d6db3..f643f5b5779 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java @@ -54,8 +54,15 @@ public class ByteBufferAccessor implements Readable, Writable { } @Override - public void readArray(byte[] arr) { + public byte[] readArray(int size) { + int remaining = buf.remaining(); + if (size > remaining) { + throw new RuntimeException("Error reading byte array of " + size + " byte(s): only " + remaining + + " byte(s) available"); + } + byte[] arr = new byte[size]; buf.get(arr); + return arr; } @Override diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/DataInputStreamReadable.java b/clients/src/main/java/org/apache/kafka/common/protocol/DataInputStreamReadable.java deleted file mode 100644 index 70ed52d6a02..00000000000 --- a/clients/src/main/java/org/apache/kafka/common/protocol/DataInputStreamReadable.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.kafka.common.protocol; - -import org.apache.kafka.common.utils.ByteUtils; - -import java.io.Closeable; -import java.io.DataInputStream; -import java.io.IOException; -import java.nio.ByteBuffer; - -public class DataInputStreamReadable implements Readable, Closeable { - protected final DataInputStream input; - - public DataInputStreamReadable(DataInputStream input) { - this.input = input; - } - - @Override - public byte readByte() { - try { - return input.readByte(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public short readShort() { - try { - return input.readShort(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public int readInt() { - try { - return input.readInt(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public long readLong() { - try { - return input.readLong(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public double readDouble() { - try { - return input.readDouble(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public void readArray(byte[] arr) { - try { - input.readFully(arr); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public int readUnsignedVarint() { - try { - return ByteUtils.readUnsignedVarint(input); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public ByteBuffer readByteBuffer(int length) { - byte[] arr = new byte[length]; - readArray(arr); - return ByteBuffer.wrap(arr); - } - - @Override - public int readVarint() { - try { - return ByteUtils.readVarint(input); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public long readVarlong() { - try { - return ByteUtils.readVarlong(input); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public int remaining() { - try { - return input.available(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public void close() { - try { - input.close(); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - -} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java b/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java index 561696827df..80bee867482 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java @@ -32,7 +32,7 @@ public interface Readable { int readInt(); long readLong(); double readDouble(); - void readArray(byte[] arr); + byte[] readArray(int length); int readUnsignedVarint(); ByteBuffer readByteBuffer(int length); int readVarint(); @@ -40,8 +40,7 @@ public interface Readable { int remaining(); default String readString(int length) { - byte[] arr = new byte[length]; - readArray(arr); + byte[] arr = readArray(length); return new String(arr, StandardCharsets.UTF_8); } @@ -49,8 +48,7 @@ public interface Readable { if (unknowns == null) { unknowns = new ArrayList<>(); } - byte[] data = new byte[size]; - readArray(data); + byte[] data = readArray(size); unknowns.add(new RawTaggedField(tag, data)); return unknowns; } diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java index 8772556b1de..b2235fef490 100644 --- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java +++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java @@ -342,6 +342,8 @@ public class DefaultRecord implements Record { int numHeaders = ByteUtils.readVarint(buffer); if (numHeaders < 0) throw new InvalidRecordException("Found invalid number of record headers " + numHeaders); + if (numHeaders > buffer.remaining()) + throw new InvalidRecordException("Found invalid number of record headers. " + numHeaders + " is larger than the remaining size of the buffer"); final Header[] headers; if (numHeaders == 0) diff --git a/clients/src/test/java/org/apache/kafka/common/message/SimpleArraysMessageTest.java b/clients/src/test/java/org/apache/kafka/common/message/SimpleArraysMessageTest.java new file mode 100644 index 00000000000..1b78adbb962 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/message/SimpleArraysMessageTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.message; + +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class SimpleArraysMessageTest { + @Test + public void testArrayBoundsChecking() { + // SimpleArraysMessageData takes 2 arrays + final ByteBuffer buf = ByteBuffer.wrap(new byte[] { + (byte) 0x7f, // Set size of first array to 126 which is larger than the size of this buffer + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00 + }); + final SimpleArraysMessageData out = new SimpleArraysMessageData(); + ByteBufferAccessor accessor = new ByteBufferAccessor(buf); + assertEquals("Tried to allocate a collection of size 126, but there are only 7 bytes remaining.", + assertThrows(RuntimeException.class, () -> out.read(accessor, (short) 2)).getMessage()); + } + + @Test + public void testArrayBoundsCheckingOtherArray() { + // SimpleArraysMessageData takes 2 arrays + final ByteBuffer buf = ByteBuffer.wrap(new byte[] { + (byte) 0x01, // Set size of first array to 0 + (byte) 0x7e, // Set size of second array to 125 which is larger than the size of this buffer + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00 + }); + final SimpleArraysMessageData out = new SimpleArraysMessageData(); + ByteBufferAccessor accessor = new ByteBufferAccessor(buf); + assertEquals("Tried to allocate a collection of size 125, but there are only 6 bytes remaining.", + assertThrows(RuntimeException.class, () -> out.read(accessor, (short) 2)).getMessage()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/ByteBufferAccessorTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/ByteBufferAccessorTest.java new file mode 100644 index 00000000000..6a0c6c2681c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/protocol/ByteBufferAccessorTest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ByteBufferAccessorTest { + @Test + public void testReadArray() { + ByteBuffer buf = ByteBuffer.allocate(1024); + ByteBufferAccessor accessor = new ByteBufferAccessor(buf); + final byte[] testArray = new byte[] {0x4b, 0x61, 0x46}; + accessor.writeByteArray(testArray); + accessor.writeInt(12345); + accessor.flip(); + final byte[] testArray2 = accessor.readArray(3); + assertArrayEquals(testArray, testArray2); + assertEquals(12345, accessor.readInt()); + assertEquals("Error reading byte array of 3 byte(s): only 0 byte(s) available", + assertThrows(RuntimeException.class, + () -> accessor.readArray(3)).getMessage()); + } + + @Test + public void testReadString() { + ByteBuffer buf = ByteBuffer.allocate(1024); + ByteBufferAccessor accessor = new ByteBufferAccessor(buf); + String testString = "ABC"; + final byte[] testArray = testString.getBytes(StandardCharsets.UTF_8); + accessor.writeByteArray(testArray); + accessor.flip(); + assertEquals("ABC", accessor.readString(3)); + assertEquals("Error reading byte array of 2 byte(s): only 0 byte(s) available", + assertThrows(RuntimeException.class, + () -> accessor.readString(2)).getMessage()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java index 49743d23201..67212165fc3 100644 --- a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java +++ b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java @@ -247,6 +247,20 @@ public class DefaultRecordTest { buf.flip(); assertThrows(InvalidRecordException.class, () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + + ByteBuffer buf2 = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf2); + buf2.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf2); + ByteUtils.writeVarint(offsetDelta, buf2); + ByteUtils.writeVarint(-1, buf2); // null key + ByteUtils.writeVarint(-1, buf2); // null value + ByteUtils.writeVarint(sizeOfBodyInBytes, buf2); // more headers than remaining buffer size, not allowed + buf2.position(buf2.limit()); + + buf2.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf2, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); } @Test diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java index 4415ff960aa..254dea0430e 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java @@ -16,22 +16,31 @@ */ package org.apache.kafka.common.requests; +import org.apache.kafka.common.errors.InvalidRequestException; import org.apache.kafka.common.message.ApiVersionsResponseData; import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection; import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.message.SaslAuthenticateRequestData; import org.apache.kafka.common.network.ClientInformation; import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.network.Send; import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.common.protocol.ByteBufferAccessor; import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.ObjectSerializationCache; import org.apache.kafka.common.security.auth.KafkaPrincipal; import org.apache.kafka.common.security.auth.SecurityProtocol; import org.junit.jupiter.api.Test; import java.net.InetAddress; +import java.net.UnknownHostException; import java.nio.ByteBuffer; +import java.util.Collections; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class RequestContextTest { @@ -104,4 +113,78 @@ public class RequestContextTest { assertEquals(expectedResponse, parsedResponse.data()); } + @Test + public void testInvalidRequestForImplicitHashCollection() throws UnknownHostException { + short version = (short) 5; // choose a version with fixed length encoding, for simplicity + ByteBuffer corruptBuffer = produceRequest(version); + // corrupt the length of the topics array + corruptBuffer.putInt(8, (Integer.MAX_VALUE - 1) / 2); + + RequestHeader header = new RequestHeader(ApiKeys.PRODUCE, version, "console-producer", 3); + RequestContext context = new RequestContext(header, "0", InetAddress.getLocalHost(), + KafkaPrincipal.ANONYMOUS, new ListenerName("ssl"), SecurityProtocol.SASL_SSL, + ClientInformation.EMPTY, true); + + String msg = assertThrows(InvalidRequestException.class, + () -> context.parseRequest(corruptBuffer)).getCause().getMessage(); + assertEquals("Tried to allocate a collection of size 1073741823, but there are only 17 bytes remaining.", msg); + } + + @Test + public void testInvalidRequestForArrayList() throws UnknownHostException { + short version = (short) 5; // choose a version with fixed length encoding, for simplicity + ByteBuffer corruptBuffer = produceRequest(version); + // corrupt the length of the partitions array + corruptBuffer.putInt(17, Integer.MAX_VALUE); + + RequestHeader header = new RequestHeader(ApiKeys.PRODUCE, version, "console-producer", 3); + RequestContext context = new RequestContext(header, "0", InetAddress.getLocalHost(), + KafkaPrincipal.ANONYMOUS, new ListenerName("ssl"), SecurityProtocol.SASL_SSL, + ClientInformation.EMPTY, true); + + String msg = assertThrows(InvalidRequestException.class, + () -> context.parseRequest(corruptBuffer)).getCause().getMessage(); + assertEquals( + "Tried to allocate a collection of size 2147483647, but there are only 8 bytes remaining.", msg); + } + + private ByteBuffer produceRequest(short version) { + ProduceRequestData data = new ProduceRequestData() + .setAcks((short) -1) + .setTimeoutMs(1); + data.topicData().add( + new ProduceRequestData.TopicProduceData() + .setName("foo") + .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData() + .setIndex(42)))); + + return serialize(version, data); + } + + private ByteBuffer serialize(short version, ApiMessage data) { + ObjectSerializationCache cache = new ObjectSerializationCache(); + data.size(cache, version); + ByteBuffer buffer = ByteBuffer.allocate(1024); + data.write(new ByteBufferAccessor(buffer), cache, version); + buffer.flip(); + return buffer; + } + + @Test + public void testInvalidRequestForByteArray() throws UnknownHostException { + short version = (short) 1; // choose a version with fixed length encoding, for simplicity + ByteBuffer corruptBuffer = serialize(version, new SaslAuthenticateRequestData().setAuthBytes(new byte[0])); + // corrupt the length of the bytes array + corruptBuffer.putInt(0, Integer.MAX_VALUE); + + RequestHeader header = new RequestHeader(ApiKeys.SASL_AUTHENTICATE, version, "console-producer", 1); + RequestContext context = new RequestContext(header, "0", InetAddress.getLocalHost(), + KafkaPrincipal.ANONYMOUS, new ListenerName("ssl"), SecurityProtocol.SASL_SSL, + ClientInformation.EMPTY, true); + + String msg = assertThrows(InvalidRequestException.class, + () -> context.parseRequest(corruptBuffer)).getCause().getMessage(); + assertEquals("Error reading byte array of 2147483647 byte(s): only 0 byte(s) available", msg); + } + } diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java index b6df4c44d63..390cacf4337 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java @@ -210,6 +210,7 @@ import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ByteBufferAccessor; import org.apache.kafka.common.protocol.Errors; import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.types.RawTaggedField; import org.apache.kafka.common.quota.ClientQuotaAlteration; import org.apache.kafka.common.quota.ClientQuotaEntity; import org.apache.kafka.common.quota.ClientQuotaFilter; @@ -231,6 +232,7 @@ import org.apache.kafka.common.security.token.delegation.DelegationToken; import org.apache.kafka.common.security.token.delegation.TokenInformation; import org.apache.kafka.common.utils.SecurityUtils; import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.nio.BufferUnderflowException; @@ -269,6 +271,7 @@ import static org.apache.kafka.common.protocol.ApiKeys.LIST_OFFSETS; import static org.apache.kafka.common.protocol.ApiKeys.METADATA; import static org.apache.kafka.common.protocol.ApiKeys.OFFSET_FETCH; import static org.apache.kafka.common.protocol.ApiKeys.PRODUCE; +import static org.apache.kafka.common.protocol.ApiKeys.SASL_AUTHENTICATE; import static org.apache.kafka.common.protocol.ApiKeys.STOP_REPLICA; import static org.apache.kafka.common.protocol.ApiKeys.SYNC_GROUP; import static org.apache.kafka.common.protocol.ApiKeys.UPDATE_METADATA; @@ -3360,4 +3363,92 @@ public class RequestResponseTest { return new ListTransactionsResponse(response); } + @Test + public void testInvalidSaslHandShakeRequest() { + AbstractRequest request = new SaslHandshakeRequest.Builder( + new SaslHandshakeRequestData().setMechanism("PLAIN")).build(); + ByteBuffer serializedBytes = request.serialize(); + // corrupt the length of the sasl mechanism string + serializedBytes.putShort(0, Short.MAX_VALUE); + + String msg = assertThrows(RuntimeException.class, () -> AbstractRequest. + parseRequest(request.apiKey(), request.version(), serializedBytes)).getMessage(); + assertEquals("Error reading byte array of 32767 byte(s): only 5 byte(s) available", msg); + } + + @Test + public void testInvalidSaslAuthenticateRequest() { + short version = (short) 1; // choose a version with fixed length encoding, for simplicity + byte[] b = new byte[] { + 0x11, 0x1f, 0x15, 0x2c, + 0x5e, 0x2a, 0x20, 0x26, + 0x6c, 0x39, 0x45, 0x1f, + 0x25, 0x1c, 0x2d, 0x25, + 0x43, 0x2a, 0x11, 0x76 + }; + SaslAuthenticateRequestData data = new SaslAuthenticateRequestData().setAuthBytes(b); + AbstractRequest request = new SaslAuthenticateRequest(data, version); + ByteBuffer serializedBytes = request.serialize(); + + // corrupt the length of the bytes array + serializedBytes.putInt(0, Integer.MAX_VALUE); + + String msg = assertThrows(RuntimeException.class, () -> AbstractRequest. + parseRequest(request.apiKey(), request.version(), serializedBytes)).getMessage(); + assertEquals("Error reading byte array of 2147483647 byte(s): only 20 byte(s) available", msg); + } + + @Test + public void testValidTaggedFieldsWithSaslAuthenticateRequest() { + byte[] byteArray = new byte[11]; + ByteBufferAccessor accessor = new ByteBufferAccessor(ByteBuffer.wrap(byteArray)); + + //construct a SASL_AUTHENTICATE request + byte[] authBytes = "test".getBytes(StandardCharsets.UTF_8); + accessor.writeUnsignedVarint(authBytes.length + 1); + accessor.writeByteArray(authBytes); + + //write total numbers of tags + accessor.writeUnsignedVarint(1); + + //write first tag + RawTaggedField taggedField = new RawTaggedField(1, new byte[] {0x1, 0x2, 0x3}); + accessor.writeUnsignedVarint(taggedField.tag()); + accessor.writeUnsignedVarint(taggedField.size()); + accessor.writeByteArray(taggedField.data()); + + accessor.flip(); + + SaslAuthenticateRequest saslAuthenticateRequest = (SaslAuthenticateRequest) AbstractRequest. + parseRequest(SASL_AUTHENTICATE, SASL_AUTHENTICATE.latestVersion(), accessor.buffer()).request; + Assertions.assertArrayEquals(authBytes, saslAuthenticateRequest.data().authBytes()); + assertEquals(1, saslAuthenticateRequest.data().unknownTaggedFields().size()); + assertEquals(taggedField, saslAuthenticateRequest.data().unknownTaggedFields().get(0)); + } + + @Test + public void testInvalidTaggedFieldsWithSaslAuthenticateRequest() { + byte[] byteArray = new byte[13]; + ByteBufferAccessor accessor = new ByteBufferAccessor(ByteBuffer.wrap(byteArray)); + + //construct a SASL_AUTHENTICATE request + byte[] authBytes = "test".getBytes(StandardCharsets.UTF_8); + accessor.writeUnsignedVarint(authBytes.length + 1); + accessor.writeByteArray(authBytes); + + //write total numbers of tags + accessor.writeUnsignedVarint(1); + + //write first tag + RawTaggedField taggedField = new RawTaggedField(1, new byte[] {0x1, 0x2, 0x3}); + accessor.writeUnsignedVarint(taggedField.tag()); + accessor.writeUnsignedVarint(Short.MAX_VALUE); // set wrong size for tagged field + accessor.writeByteArray(taggedField.data()); + + accessor.flip(); + + String msg = assertThrows(RuntimeException.class, () -> AbstractRequest. + parseRequest(SASL_AUTHENTICATE, SASL_AUTHENTICATE.latestVersion(), accessor.buffer())).getMessage(); + assertEquals("Error reading byte array of 32767 byte(s): only 3 byte(s) available", msg); + } } diff --git a/clients/src/test/resources/common/message/SimpleArraysMessage.json b/clients/src/test/resources/common/message/SimpleArraysMessage.json new file mode 100644 index 00000000000..76dc283b6a7 --- /dev/null +++ b/clients/src/test/resources/common/message/SimpleArraysMessage.json @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +{ + "name": "SimpleArraysMessage", + "type": "header", + "validVersions": "0-2", + "flexibleVersions": "1+", + "fields": [ + { "name": "Goats", "type": "[]StructArray", "versions": "1+", + "fields": [ + { "name": "Color", "type": "int8", "versions": "1+"}, + { "name": "Name", "type": "string", "versions": "2+"} + ] + }, + { "name": "Sheep", "type": "[]int32", "versions": "0+" } + ] +} diff --git a/core/src/main/scala/kafka/tools/TestRaftServer.scala b/core/src/main/scala/kafka/tools/TestRaftServer.scala index a72784c469a..ef97b6ccdaa 100644 --- a/core/src/main/scala/kafka/tools/TestRaftServer.scala +++ b/core/src/main/scala/kafka/tools/TestRaftServer.scala @@ -299,11 +299,7 @@ object TestRaftServer extends Logging { out.writeByteArray(data) } - override def read(input: protocol.Readable, size: Int): Array[Byte] = { - val data = new Array[Byte](size) - input.readArray(data) - data - } + override def read(input: protocol.Readable, size: Int): Array[Byte] = input.readArray(size) } private class LatencyHistogram( diff --git a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala index 99b4f817be4..fa9885e3583 100644 --- a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala +++ b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala @@ -1010,11 +1010,7 @@ object KafkaMetadataLogTest { override def write(data: Array[Byte], serializationCache: ObjectSerializationCache, out: Writable): Unit = { out.writeByteArray(data) } - override def read(input: protocol.Readable, size: Int): Array[Byte] = { - val array = new Array[Byte](size) - input.readArray(array) - array - } + override def read(input: protocol.Readable, size: Int): Array[Byte] = input.readArray(size) } val DefaultMetadataLogConfig = MetadataLogConfig( diff --git a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java index 235667480cc..45188fa8109 100644 --- a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java +++ b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java @@ -612,8 +612,7 @@ public final class MessageDataGenerator implements MessageClassGenerator { buffer.printf("%s_readable.readByteBuffer(%s)%s", assignmentPrefix, lengthVar, assignmentSuffix); } else { - buffer.printf("byte[] newBytes = new byte[%s];%n", lengthVar); - buffer.printf("_readable.readArray(newBytes);%n"); + buffer.printf("byte[] newBytes = _readable.readArray(%s);%n", lengthVar); buffer.printf("%snewBytes%s", assignmentPrefix, assignmentSuffix); } } else if (type.isRecords()) { @@ -621,6 +620,12 @@ public final class MessageDataGenerator implements MessageClassGenerator { assignmentPrefix, lengthVar, assignmentSuffix); } else if (type.isArray()) { FieldType.ArrayType arrayType = (FieldType.ArrayType) type; + buffer.printf("if (%s > _readable.remaining()) {%n", lengthVar); + buffer.incrementIndent(); + buffer.printf("throw new RuntimeException(\"Tried to allocate a collection of size \" + %s + \", but " + + "there are only \" + _readable.remaining() + \" bytes remaining.\");%n", lengthVar); + buffer.decrementIndent(); + buffer.printf("}%n"); if (isStructArrayWithKeys) { headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); buffer.printf("%s newCollection = new %s(%s);%n", diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java b/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java index ff415aa72ad..efb1d69a34f 100644 --- a/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java +++ b/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java @@ -16,6 +16,7 @@ */ package org.apache.kafka.raft.internals; +import java.io.DataInputStream; import java.io.IOException; import java.io.UncheckedIOException; import java.nio.ByteBuffer; @@ -25,14 +26,16 @@ import java.util.Iterator; import java.util.List; import java.util.NoSuchElementException; import java.util.Optional; -import org.apache.kafka.common.protocol.DataInputStreamReadable; -import org.apache.kafka.common.protocol.Readable; + +import org.apache.kafka.common.protocol.ByteBufferAccessor; import org.apache.kafka.common.record.DefaultRecordBatch; import org.apache.kafka.common.record.FileRecords; import org.apache.kafka.common.record.MemoryRecords; import org.apache.kafka.common.record.MutableRecordBatch; import org.apache.kafka.common.record.Records; import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.common.utils.Utils; import org.apache.kafka.raft.Batch; import org.apache.kafka.server.common.serialization.RecordSerde; @@ -53,6 +56,13 @@ public final class RecordsIterator implements Iterator>, AutoCloseab private int bytesRead = 0; private boolean isClosed = false; + /** + * This class provides an iterator over records retrieved via the raft client or from a snapshot + * @param records the records + * @param serde the serde to deserialize records + * @param bufferSupplier the buffer supplier implementation to allocate buffers when reading records. This must return ByteBuffer allocated on the heap + * @param batchSize the maximum batch size + */ public RecordsIterator( Records records, RecordSerde serde, @@ -99,7 +109,7 @@ public final class RecordsIterator implements Iterator>, AutoCloseab private void ensureOpen() { if (isClosed) { - throw new IllegalStateException("Serde record batch itererator was closed"); + throw new IllegalStateException("Serde record batch iterator was closed"); } } @@ -205,11 +215,14 @@ public final class RecordsIterator implements Iterator>, AutoCloseab } List records = new ArrayList<>(numRecords); - try (DataInputStreamReadable input = new DataInputStreamReadable(batch.recordInputStream(bufferSupplier))) { + DataInputStream input = new DataInputStream(batch.recordInputStream(bufferSupplier)); + try { for (int i = 0; i < numRecords; i++) { - T record = readRecord(input); + T record = readRecord(input, batch.sizeInBytes()); records.add(record); } + } finally { + Utils.closeQuietly(input, "DataInputStream"); } result = Batch.data( @@ -224,39 +237,74 @@ public final class RecordsIterator implements Iterator>, AutoCloseab return result; } - private T readRecord(Readable input) { + private T readRecord(DataInputStream stream, int totalBatchSize) { // Read size of body in bytes - input.readVarint(); - - // Read unused attributes - input.readByte(); - - long timestampDelta = input.readVarlong(); - if (timestampDelta != 0) { - throw new IllegalArgumentException(); + int size; + try { + size = ByteUtils.readVarint(stream); + } catch (IOException e) { + throw new UncheckedIOException("Unable to read record size", e); } - - // Read offset delta - input.readVarint(); - - int keySize = input.readVarint(); - if (keySize != -1) { - throw new IllegalArgumentException("Unexpected key size " + keySize); + if (size <= 0) { + throw new RuntimeException("Invalid non-positive frame size: " + size); } - - int valueSize = input.readVarint(); - if (valueSize < 0) { - throw new IllegalArgumentException(); + if (size > totalBatchSize) { + throw new RuntimeException("Specified frame size, " + size + ", is larger than the entire size of the " + + "batch, which is " + totalBatchSize); } + ByteBuffer buf = bufferSupplier.get(size); - // Read the metadata record body from the file input reader - T record = serde.read(input, valueSize); + // The last byte of the buffer is reserved for a varint set to the number of record headers, which + // must be 0. Therefore, we set the ByteBuffer limit to size - 1. + buf.limit(size - 1); - int numHeaders = input.readVarint(); - if (numHeaders != 0) { - throw new IllegalArgumentException(); + try { + int bytesRead = stream.read(buf.array(), 0, size); + if (bytesRead != size) { + throw new RuntimeException("Unable to read " + size + " bytes, only read " + bytesRead); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to read record bytes", e); } + try { + ByteBufferAccessor input = new ByteBufferAccessor(buf); - return record; + // Read unused attributes + input.readByte(); + + long timestampDelta = input.readVarlong(); + if (timestampDelta != 0) { + throw new IllegalArgumentException("Got timestamp delta of " + timestampDelta + ", but this is invalid because it " + + "is not 0 as expected."); + } + + // Read offset delta + input.readVarint(); + + int keySize = input.readVarint(); + if (keySize != -1) { + throw new IllegalArgumentException("Got key size of " + keySize + ", but this is invalid because it " + + "is not -1 as expected."); + } + + int valueSize = input.readVarint(); + if (valueSize < 1) { + throw new IllegalArgumentException("Got payload size of " + valueSize + ", but this is invalid because " + + "it is less than 1."); + } + + // Read the metadata record body from the file input reader + T record = serde.read(input, valueSize); + + // Read the number of headers. Currently, this must be a single byte set to 0. + int numHeaders = buf.array()[size - 1]; + if (numHeaders != 0) { + throw new IllegalArgumentException("Got numHeaders of " + numHeaders + ", but this is invalid because " + + "it is not 0 as expected."); + } + return record; + } finally { + bufferSupplier.release(buf); + } } } diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java b/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java index c2a011a687d..15475be319f 100644 --- a/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java +++ b/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java @@ -40,8 +40,7 @@ public class StringSerde implements RecordSerde { @Override public String read(Readable input, int size) { - byte[] data = new byte[size]; - input.readArray(data); + byte[] data = input.readArray(size); return Utils.utf8(data); }