diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index ef3947ad91a..55fcd1a9e53 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -170,6 +170,10 @@
+
+
+
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java b/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java
index bd0925d6db3..f643f5b5779 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java
@@ -54,8 +54,15 @@ public class ByteBufferAccessor implements Readable, Writable {
}
@Override
- public void readArray(byte[] arr) {
+ public byte[] readArray(int size) {
+ int remaining = buf.remaining();
+ if (size > remaining) {
+ throw new RuntimeException("Error reading byte array of " + size + " byte(s): only " + remaining +
+ " byte(s) available");
+ }
+ byte[] arr = new byte[size];
buf.get(arr);
+ return arr;
}
@Override
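The bounds check above fails fast with a descriptive message instead of allocating `size` bytes on the basis of an untrusted length field. A minimal sketch of the new behavior (hypothetical demo class, not part of the patch):

    import java.nio.ByteBuffer;
    import org.apache.kafka.common.protocol.ByteBufferAccessor;

    public class ReadArrayDemo {
        public static void main(String[] args) {
            ByteBufferAccessor accessor = new ByteBufferAccessor(ByteBuffer.wrap(new byte[3]));
            // Fails before any allocation happens:
            // "Error reading byte array of 4 byte(s): only 3 byte(s) available"
            accessor.readArray(4);
        }
    }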
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/DataInputStreamReadable.java b/clients/src/main/java/org/apache/kafka/common/protocol/DataInputStreamReadable.java
deleted file mode 100644
index 70ed52d6a02..00000000000
--- a/clients/src/main/java/org/apache/kafka/common/protocol/DataInputStreamReadable.java
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.common.protocol;
-
-import org.apache.kafka.common.utils.ByteUtils;
-
-import java.io.Closeable;
-import java.io.DataInputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-public class DataInputStreamReadable implements Readable, Closeable {
- protected final DataInputStream input;
-
- public DataInputStreamReadable(DataInputStream input) {
- this.input = input;
- }
-
- @Override
- public byte readByte() {
- try {
- return input.readByte();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public short readShort() {
- try {
- return input.readShort();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public int readInt() {
- try {
- return input.readInt();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public long readLong() {
- try {
- return input.readLong();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public double readDouble() {
- try {
- return input.readDouble();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public void readArray(byte[] arr) {
- try {
- input.readFully(arr);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public int readUnsignedVarint() {
- try {
- return ByteUtils.readUnsignedVarint(input);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public ByteBuffer readByteBuffer(int length) {
- byte[] arr = new byte[length];
- readArray(arr);
- return ByteBuffer.wrap(arr);
- }
-
- @Override
- public int readVarint() {
- try {
- return ByteUtils.readVarint(input);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public long readVarlong() {
- try {
- return ByteUtils.readVarlong(input);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public int remaining() {
- try {
- return input.available();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- @Override
- public void close() {
- try {
- input.close();
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
-}
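Deleting DataInputStreamReadable is what forces the signature change: its remaining() was backed by InputStream.available(), which is only an estimate, so it could never support a reliable up-front bounds check in readArray(int). Its caller in this patch, RecordsIterator (below), now copies each framed record into a heap buffer and reads it through ByteBufferAccessor. A sketch of that replacement pattern, with hypothetical names (toReadable is not part of the patch):

    import java.io.DataInputStream;
    import java.io.IOException;
    import java.nio.ByteBuffer;
    import org.apache.kafka.common.protocol.ByteBufferAccessor;
    import org.apache.kafka.common.protocol.Readable;
    import org.apache.kafka.common.utils.BufferSupplier;

    public class ReadablePattern {
        // Hypothetical helper illustrating the pattern RecordsIterator adopts below
        static Readable toReadable(DataInputStream stream, int size, BufferSupplier bufferSupplier) throws IOException {
            ByteBuffer buf = bufferSupplier.get(size); // heap-allocated, per the new RecordsIterator javadoc
            int read = stream.read(buf.array(), 0, size);
            if (read != size) {
                throw new RuntimeException("Unable to read " + size + " bytes, only read " + read);
            }
            return new ByteBufferAccessor(buf); // every subsequent read is bounds-checked
        }
    }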
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java b/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java
index 561696827df..80bee867482 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java
@@ -32,7 +32,7 @@ public interface Readable {
int readInt();
long readLong();
double readDouble();
- void readArray(byte[] arr);
+ byte[] readArray(int length);
int readUnsignedVarint();
ByteBuffer readByteBuffer(int length);
int readVarint();
@@ -40,8 +40,7 @@ public interface Readable {
int remaining();
default String readString(int length) {
- byte[] arr = new byte[length];
- readArray(arr);
+ byte[] arr = readArray(length);
return new String(arr, StandardCharsets.UTF_8);
}
@@ -49,8 +48,7 @@ public interface Readable {
if (unknowns == null) {
unknowns = new ArrayList<>();
}
- byte[] data = new byte[size];
- readArray(data);
+ byte[] data = readArray(size);
unknowns.add(new RawTaggedField(tag, data));
return unknowns;
}
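Because the readString and readUnknownTaggedField defaults now delegate to readArray, string and tagged-field parsing get the same pre-allocation validation for free. A quick sketch (hypothetical demo, mirroring ByteBufferAccessorTest below):

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;
    import org.apache.kafka.common.protocol.ByteBufferAccessor;

    public class ReadStringDemo {
        public static void main(String[] args) {
            byte[] abc = "ABC".getBytes(StandardCharsets.UTF_8);
            ByteBufferAccessor accessor = new ByteBufferAccessor(ByteBuffer.wrap(abc));
            System.out.println(accessor.readString(3)); // "ABC"
            // Throws: "Error reading byte array of 2 byte(s): only 0 byte(s) available"
            accessor.readString(2);
        }
    }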
diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
index 8772556b1de..b2235fef490 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java
@@ -342,6 +342,8 @@ public class DefaultRecord implements Record {
int numHeaders = ByteUtils.readVarint(buffer);
if (numHeaders < 0)
throw new InvalidRecordException("Found invalid number of record headers " + numHeaders);
+ if (numHeaders > buffer.remaining())
+ throw new InvalidRecordException("Found invalid number of record headers. " + numHeaders + " is larger than the remaining size of the buffer");
final Header[] headers;
if (numHeaders == 0)
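The check is sound because every record header occupies at least one byte on the wire, so a well-formed record can never declare more headers than there are bytes remaining; rejecting such counts here prevents the oversized Header[] allocation that used to follow. A small sketch of the condition (assumed buffer contents, for illustration):

    import java.nio.ByteBuffer;
    import org.apache.kafka.common.utils.ByteUtils;

    public class HeaderCountDemo {
        public static void main(String[] args) {
            ByteBuffer buf = ByteBuffer.allocate(8);
            ByteUtils.writeVarint(1_000_000, buf); // claim a million headers
            buf.flip();
            int numHeaders = ByteUtils.readVarint(buf);
            // Only a few bytes remain, so the new guard rejects the count
            System.out.println(numHeaders > buf.remaining()); // true
        }
    }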
diff --git a/clients/src/test/java/org/apache/kafka/common/message/SimpleArraysMessageTest.java b/clients/src/test/java/org/apache/kafka/common/message/SimpleArraysMessageTest.java
new file mode 100644
index 00000000000..1b78adbb962
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/message/SimpleArraysMessageTest.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.message;
+
+import org.apache.kafka.common.protocol.ByteBufferAccessor;
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class SimpleArraysMessageTest {
+ @Test
+ public void testArrayBoundsChecking() {
+ // SimpleArraysMessageData takes 2 arrays
+ final ByteBuffer buf = ByteBuffer.wrap(new byte[] {
+ (byte) 0x7f, // Set size of first array to 126, which is larger than the size of this buffer
+ (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00
+ });
+ final SimpleArraysMessageData out = new SimpleArraysMessageData();
+ ByteBufferAccessor accessor = new ByteBufferAccessor(buf);
+ assertEquals("Tried to allocate a collection of size 126, but there are only 7 bytes remaining.",
+ assertThrows(RuntimeException.class, () -> out.read(accessor, (short) 2)).getMessage());
+ }
+
+ @Test
+ public void testArrayBoundsCheckingOtherArray() {
+ // SimpleArraysMessageData takes 2 arrays
+ final ByteBuffer buf = ByteBuffer.wrap(new byte[] {
+ (byte) 0x01, // Set size of first array to 0
+ (byte) 0x7e, // Set size of second array to 125, which is larger than the size of this buffer
+ (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00
+ });
+ final SimpleArraysMessageData out = new SimpleArraysMessageData();
+ ByteBufferAccessor accessor = new ByteBufferAccessor(buf);
+ assertEquals("Tried to allocate a collection of size 125, but there are only 6 bytes remaining.",
+ assertThrows(RuntimeException.class, () -> out.read(accessor, (short) 2)).getMessage());
+ }
+}
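The magic bytes in both tests come from the compact (flexible-version) encoding: array lengths are written as unsigned varints holding length + 1, so 0x7f decodes to 127 and declares 126 elements, while 0x01 declares an empty array. A small decoding sketch under those assumptions:

    import java.nio.ByteBuffer;
    import org.apache.kafka.common.utils.ByteUtils;

    public class CompactArrayLengthDemo {
        public static void main(String[] args) {
            ByteBuffer buf = ByteBuffer.wrap(new byte[] {(byte) 0x7f});
            // Compact arrays store N + 1, so varint 127 declares a 126-element array
            System.out.println(ByteUtils.readUnsignedVarint(buf) - 1); // 126
        }
    }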
diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/ByteBufferAccessorTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/ByteBufferAccessorTest.java
new file mode 100644
index 00000000000..6a0c6c2681c
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/protocol/ByteBufferAccessorTest.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.protocol;
+
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class ByteBufferAccessorTest {
+ @Test
+ public void testReadArray() {
+ ByteBuffer buf = ByteBuffer.allocate(1024);
+ ByteBufferAccessor accessor = new ByteBufferAccessor(buf);
+ final byte[] testArray = new byte[] {0x4b, 0x61, 0x46};
+ accessor.writeByteArray(testArray);
+ accessor.writeInt(12345);
+ accessor.flip();
+ final byte[] testArray2 = accessor.readArray(3);
+ assertArrayEquals(testArray, testArray2);
+ assertEquals(12345, accessor.readInt());
+ assertEquals("Error reading byte array of 3 byte(s): only 0 byte(s) available",
+ assertThrows(RuntimeException.class,
+ () -> accessor.readArray(3)).getMessage());
+ }
+
+ @Test
+ public void testReadString() {
+ ByteBuffer buf = ByteBuffer.allocate(1024);
+ ByteBufferAccessor accessor = new ByteBufferAccessor(buf);
+ String testString = "ABC";
+ final byte[] testArray = testString.getBytes(StandardCharsets.UTF_8);
+ accessor.writeByteArray(testArray);
+ accessor.flip();
+ assertEquals("ABC", accessor.readString(3));
+ assertEquals("Error reading byte array of 2 byte(s): only 0 byte(s) available",
+ assertThrows(RuntimeException.class,
+ () -> accessor.readString(2)).getMessage());
+ }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java
index 49743d23201..67212165fc3 100644
--- a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java
@@ -247,6 +247,20 @@ public class DefaultRecordTest {
buf.flip();
assertThrows(InvalidRecordException.class,
() -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null));
+
+ ByteBuffer buf2 = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes));
+ ByteUtils.writeVarint(sizeOfBodyInBytes, buf2);
+ buf2.put(attributes);
+ ByteUtils.writeVarlong(timestampDelta, buf2);
+ ByteUtils.writeVarint(offsetDelta, buf2);
+ ByteUtils.writeVarint(-1, buf2); // null key
+ ByteUtils.writeVarint(-1, buf2); // null value
+ ByteUtils.writeVarint(sizeOfBodyInBytes, buf2); // more headers than remaining buffer size, not allowed
+ buf2.position(buf2.limit());
+
+ buf2.flip();
+ assertThrows(InvalidRecordException.class,
+ () -> DefaultRecord.readFrom(buf2, 0L, 0L, RecordBatch.NO_SEQUENCE, null));
}
@Test
diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java
index 4415ff960aa..254dea0430e 100644
--- a/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java
@@ -16,22 +16,31 @@
*/
package org.apache.kafka.common.requests;
+import org.apache.kafka.common.errors.InvalidRequestException;
import org.apache.kafka.common.message.ApiVersionsResponseData;
import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection;
import org.apache.kafka.common.message.CreateTopicsResponseData;
+import org.apache.kafka.common.message.ProduceRequestData;
+import org.apache.kafka.common.message.SaslAuthenticateRequestData;
import org.apache.kafka.common.network.ClientInformation;
import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.network.Send;
import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ApiMessage;
+import org.apache.kafka.common.protocol.ByteBufferAccessor;
import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.protocol.ObjectSerializationCache;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.auth.SecurityProtocol;
import org.junit.jupiter.api.Test;
import java.net.InetAddress;
+import java.net.UnknownHostException;
import java.nio.ByteBuffer;
+import java.util.Collections;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class RequestContextTest {
@@ -104,4 +113,78 @@ public class RequestContextTest {
assertEquals(expectedResponse, parsedResponse.data());
}
+ @Test
+ public void testInvalidRequestForImplicitHashCollection() throws UnknownHostException {
+ short version = (short) 5; // choose a version with fixed length encoding, for simplicity
+ ByteBuffer corruptBuffer = produceRequest(version);
+ // corrupt the length of the topics array
+ corruptBuffer.putInt(8, (Integer.MAX_VALUE - 1) / 2);
+
+ RequestHeader header = new RequestHeader(ApiKeys.PRODUCE, version, "console-producer", 3);
+ RequestContext context = new RequestContext(header, "0", InetAddress.getLocalHost(),
+ KafkaPrincipal.ANONYMOUS, new ListenerName("ssl"), SecurityProtocol.SASL_SSL,
+ ClientInformation.EMPTY, true);
+
+ String msg = assertThrows(InvalidRequestException.class,
+ () -> context.parseRequest(corruptBuffer)).getCause().getMessage();
+ assertEquals("Tried to allocate a collection of size 1073741823, but there are only 17 bytes remaining.", msg);
+ }
+
+ @Test
+ public void testInvalidRequestForArrayList() throws UnknownHostException {
+ short version = (short) 5; // choose a version with fixed length encoding, for simplicity
+ ByteBuffer corruptBuffer = produceRequest(version);
+ // corrupt the length of the partitions array
+ corruptBuffer.putInt(17, Integer.MAX_VALUE);
+
+ RequestHeader header = new RequestHeader(ApiKeys.PRODUCE, version, "console-producer", 3);
+ RequestContext context = new RequestContext(header, "0", InetAddress.getLocalHost(),
+ KafkaPrincipal.ANONYMOUS, new ListenerName("ssl"), SecurityProtocol.SASL_SSL,
+ ClientInformation.EMPTY, true);
+
+ String msg = assertThrows(InvalidRequestException.class,
+ () -> context.parseRequest(corruptBuffer)).getCause().getMessage();
+ assertEquals(
+ "Tried to allocate a collection of size 2147483647, but there are only 8 bytes remaining.", msg);
+ }
+
+ private ByteBuffer produceRequest(short version) {
+ ProduceRequestData data = new ProduceRequestData()
+ .setAcks((short) -1)
+ .setTimeoutMs(1);
+ data.topicData().add(
+ new ProduceRequestData.TopicProduceData()
+ .setName("foo")
+ .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData()
+ .setIndex(42))));
+
+ return serialize(version, data);
+ }
+
+ private ByteBuffer serialize(short version, ApiMessage data) {
+ ObjectSerializationCache cache = new ObjectSerializationCache();
+ data.size(cache, version);
+ ByteBuffer buffer = ByteBuffer.allocate(1024);
+ data.write(new ByteBufferAccessor(buffer), cache, version);
+ buffer.flip();
+ return buffer;
+ }
+
+ @Test
+ public void testInvalidRequestForByteArray() throws UnknownHostException {
+ short version = (short) 1; // choose a version with fixed length encoding, for simplicity
+ ByteBuffer corruptBuffer = serialize(version, new SaslAuthenticateRequestData().setAuthBytes(new byte[0]));
+ // corrupt the length of the bytes array
+ corruptBuffer.putInt(0, Integer.MAX_VALUE);
+
+ RequestHeader header = new RequestHeader(ApiKeys.SASL_AUTHENTICATE, version, "console-producer", 1);
+ RequestContext context = new RequestContext(header, "0", InetAddress.getLocalHost(),
+ KafkaPrincipal.ANONYMOUS, new ListenerName("ssl"), SecurityProtocol.SASL_SSL,
+ ClientInformation.EMPTY, true);
+
+ String msg = assertThrows(InvalidRequestException.class,
+ () -> context.parseRequest(corruptBuffer)).getCause().getMessage();
+ assertEquals("Error reading byte array of 2147483647 byte(s): only 0 byte(s) available", msg);
+ }
+
}
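The hard-coded offsets in these tests follow from ProduceRequest v5's fixed-length (non-flexible) layout, assuming the default serialization produced by produceRequest() above:

    // offset 0:  transactional_id length (int16, -1 for null)
    // offset 2:  acks (int16)
    // offset 4:  timeout_ms (int32)
    // offset 8:  topics array count (int32)      <- corrupted in testInvalidRequestForImplicitHashCollection
    // offset 12: topic name length (int16)
    // offset 14: "foo" (3 bytes)
    // offset 17: partitions array count (int32)  <- corrupted in testInvalidRequestForArrayList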
diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
index b6df4c44d63..390cacf4337 100644
--- a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
@@ -210,6 +210,7 @@ import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ByteBufferAccessor;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.protocol.ObjectSerializationCache;
+import org.apache.kafka.common.protocol.types.RawTaggedField;
import org.apache.kafka.common.quota.ClientQuotaAlteration;
import org.apache.kafka.common.quota.ClientQuotaEntity;
import org.apache.kafka.common.quota.ClientQuotaFilter;
@@ -231,6 +232,7 @@ import org.apache.kafka.common.security.token.delegation.DelegationToken;
import org.apache.kafka.common.security.token.delegation.TokenInformation;
import org.apache.kafka.common.utils.SecurityUtils;
import org.apache.kafka.common.utils.Utils;
+import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.nio.BufferUnderflowException;
@@ -269,6 +271,7 @@ import static org.apache.kafka.common.protocol.ApiKeys.LIST_OFFSETS;
import static org.apache.kafka.common.protocol.ApiKeys.METADATA;
import static org.apache.kafka.common.protocol.ApiKeys.OFFSET_FETCH;
import static org.apache.kafka.common.protocol.ApiKeys.PRODUCE;
+import static org.apache.kafka.common.protocol.ApiKeys.SASL_AUTHENTICATE;
import static org.apache.kafka.common.protocol.ApiKeys.STOP_REPLICA;
import static org.apache.kafka.common.protocol.ApiKeys.SYNC_GROUP;
import static org.apache.kafka.common.protocol.ApiKeys.UPDATE_METADATA;
@@ -3360,4 +3363,92 @@ public class RequestResponseTest {
return new ListTransactionsResponse(response);
}
+ @Test
+ public void testInvalidSaslHandShakeRequest() {
+ AbstractRequest request = new SaslHandshakeRequest.Builder(
+ new SaslHandshakeRequestData().setMechanism("PLAIN")).build();
+ ByteBuffer serializedBytes = request.serialize();
+ // corrupt the length of the sasl mechanism string
+ serializedBytes.putShort(0, Short.MAX_VALUE);
+
+ String msg = assertThrows(RuntimeException.class, () -> AbstractRequest.
+ parseRequest(request.apiKey(), request.version(), serializedBytes)).getMessage();
+ assertEquals("Error reading byte array of 32767 byte(s): only 5 byte(s) available", msg);
+ }
+
+ @Test
+ public void testInvalidSaslAuthenticateRequest() {
+ short version = (short) 1; // choose a version with fixed length encoding, for simplicity
+ byte[] b = new byte[] {
+ 0x11, 0x1f, 0x15, 0x2c,
+ 0x5e, 0x2a, 0x20, 0x26,
+ 0x6c, 0x39, 0x45, 0x1f,
+ 0x25, 0x1c, 0x2d, 0x25,
+ 0x43, 0x2a, 0x11, 0x76
+ };
+ SaslAuthenticateRequestData data = new SaslAuthenticateRequestData().setAuthBytes(b);
+ AbstractRequest request = new SaslAuthenticateRequest(data, version);
+ ByteBuffer serializedBytes = request.serialize();
+
+ // corrupt the length of the bytes array
+ serializedBytes.putInt(0, Integer.MAX_VALUE);
+
+ String msg = assertThrows(RuntimeException.class, () -> AbstractRequest.
+ parseRequest(request.apiKey(), request.version(), serializedBytes)).getMessage();
+ assertEquals("Error reading byte array of 2147483647 byte(s): only 20 byte(s) available", msg);
+ }
+
+ @Test
+ public void testValidTaggedFieldsWithSaslAuthenticateRequest() {
+ byte[] byteArray = new byte[11];
+ ByteBufferAccessor accessor = new ByteBufferAccessor(ByteBuffer.wrap(byteArray));
+
+ //construct a SASL_AUTHENTICATE request
+ byte[] authBytes = "test".getBytes(StandardCharsets.UTF_8);
+ accessor.writeUnsignedVarint(authBytes.length + 1);
+ accessor.writeByteArray(authBytes);
+
+ //write total numbers of tags
+ accessor.writeUnsignedVarint(1);
+
+ //write first tag
+ RawTaggedField taggedField = new RawTaggedField(1, new byte[] {0x1, 0x2, 0x3});
+ accessor.writeUnsignedVarint(taggedField.tag());
+ accessor.writeUnsignedVarint(taggedField.size());
+ accessor.writeByteArray(taggedField.data());
+
+ accessor.flip();
+
+ SaslAuthenticateRequest saslAuthenticateRequest = (SaslAuthenticateRequest) AbstractRequest.
+ parseRequest(SASL_AUTHENTICATE, SASL_AUTHENTICATE.latestVersion(), accessor.buffer()).request;
+ Assertions.assertArrayEquals(authBytes, saslAuthenticateRequest.data().authBytes());
+ assertEquals(1, saslAuthenticateRequest.data().unknownTaggedFields().size());
+ assertEquals(taggedField, saslAuthenticateRequest.data().unknownTaggedFields().get(0));
+ }
+
+ @Test
+ public void testInvalidTaggedFieldsWithSaslAuthenticateRequest() {
+ byte[] byteArray = new byte[13];
+ ByteBufferAccessor accessor = new ByteBufferAccessor(ByteBuffer.wrap(byteArray));
+
+ //construct a SASL_AUTHENTICATE request
+ byte[] authBytes = "test".getBytes(StandardCharsets.UTF_8);
+ accessor.writeUnsignedVarint(authBytes.length + 1);
+ accessor.writeByteArray(authBytes);
+
+ //write total numbers of tags
+ accessor.writeUnsignedVarint(1);
+
+ //write first tag
+ RawTaggedField taggedField = new RawTaggedField(1, new byte[] {0x1, 0x2, 0x3});
+ accessor.writeUnsignedVarint(taggedField.tag());
+ accessor.writeUnsignedVarint(Short.MAX_VALUE); // set wrong size for tagged field
+ accessor.writeByteArray(taggedField.data());
+
+ accessor.flip();
+
+ String msg = assertThrows(RuntimeException.class, () -> AbstractRequest.
+ parseRequest(SASL_AUTHENTICATE, SASL_AUTHENTICATE.latestVersion(), accessor.buffer())).getMessage();
+ assertEquals("Error reading byte array of 32767 byte(s): only 3 byte(s) available", msg);
+ }
}
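For reference, the flexible-version wire layout the two tagged-field tests assemble by hand (sizes assume the 4-byte "test" payload):

    // [uvarint] auth_bytes length + 1   -> 5
    // [4 bytes] auth_bytes              -> "test"
    // [uvarint] tagged field count      -> 1
    // [uvarint] tag                     -> 1
    // [uvarint] tag data size           -> 3, or a bogus 32767 in the invalid case
    // [3 bytes] tag data                -> {0x1, 0x2, 0x3}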
diff --git a/clients/src/test/resources/common/message/SimpleArraysMessage.json b/clients/src/test/resources/common/message/SimpleArraysMessage.json
new file mode 100644
index 00000000000..76dc283b6a7
--- /dev/null
+++ b/clients/src/test/resources/common/message/SimpleArraysMessage.json
@@ -0,0 +1,29 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+{
+ "name": "SimpleArraysMessage",
+ "type": "header",
+ "validVersions": "0-2",
+ "flexibleVersions": "1+",
+ "fields": [
+ { "name": "Goats", "type": "[]StructArray", "versions": "1+",
+ "fields": [
+ { "name": "Color", "type": "int8", "versions": "1+"},
+ { "name": "Name", "type": "string", "versions": "2+"}
+ ]
+ },
+ { "name": "Sheep", "type": "[]int32", "versions": "0+" }
+ ]
+}
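With flexibleVersions set to 1+, reads at version 2 (as in SimpleArraysMessageTest above) use the compact length-plus-one array encoding, which is why a single corrupt length byte suffices in those tests. A round-trip sketch, assuming the generated SimpleArraysMessageData class is on the test classpath:

    import java.nio.ByteBuffer;
    import org.apache.kafka.common.message.SimpleArraysMessageData;
    import org.apache.kafka.common.protocol.ByteBufferAccessor;
    import org.apache.kafka.common.protocol.ObjectSerializationCache;

    public class SimpleArraysRoundTrip {
        public static void main(String[] args) {
            SimpleArraysMessageData data = new SimpleArraysMessageData(); // empty Goats and Sheep arrays
            ObjectSerializationCache cache = new ObjectSerializationCache();
            ByteBuffer buf = ByteBuffer.allocate(data.size(cache, (short) 2));
            data.write(new ByteBufferAccessor(buf), cache, (short) 2);
            buf.flip();
            new SimpleArraysMessageData().read(new ByteBufferAccessor(buf), (short) 2); // parses cleanly
        }
    }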
diff --git a/core/src/main/scala/kafka/tools/TestRaftServer.scala b/core/src/main/scala/kafka/tools/TestRaftServer.scala
index a72784c469a..ef97b6ccdaa 100644
--- a/core/src/main/scala/kafka/tools/TestRaftServer.scala
+++ b/core/src/main/scala/kafka/tools/TestRaftServer.scala
@@ -299,11 +299,7 @@ object TestRaftServer extends Logging {
out.writeByteArray(data)
}
- override def read(input: protocol.Readable, size: Int): Array[Byte] = {
- val data = new Array[Byte](size)
- input.readArray(data)
- data
- }
+ override def read(input: protocol.Readable, size: Int): Array[Byte] = input.readArray(size)
}
private class LatencyHistogram(
diff --git a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
index 99b4f817be4..fa9885e3583 100644
--- a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
+++ b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala
@@ -1010,11 +1010,7 @@ object KafkaMetadataLogTest {
override def write(data: Array[Byte], serializationCache: ObjectSerializationCache, out: Writable): Unit = {
out.writeByteArray(data)
}
- override def read(input: protocol.Readable, size: Int): Array[Byte] = {
- val array = new Array[Byte](size)
- input.readArray(array)
- array
- }
+ override def read(input: protocol.Readable, size: Int): Array[Byte] = input.readArray(size)
}
val DefaultMetadataLogConfig = MetadataLogConfig(
diff --git a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java
index 235667480cc..45188fa8109 100644
--- a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java
+++ b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java
@@ -612,8 +612,7 @@ public final class MessageDataGenerator implements MessageClassGenerator {
buffer.printf("%s_readable.readByteBuffer(%s)%s",
assignmentPrefix, lengthVar, assignmentSuffix);
} else {
- buffer.printf("byte[] newBytes = new byte[%s];%n", lengthVar);
- buffer.printf("_readable.readArray(newBytes);%n");
+ buffer.printf("byte[] newBytes = _readable.readArray(%s);%n", lengthVar);
buffer.printf("%snewBytes%s", assignmentPrefix, assignmentSuffix);
}
} else if (type.isRecords()) {
@@ -621,6 +620,12 @@ public final class MessageDataGenerator implements MessageClassGenerator {
assignmentPrefix, lengthVar, assignmentSuffix);
} else if (type.isArray()) {
FieldType.ArrayType arrayType = (FieldType.ArrayType) type;
+ buffer.printf("if (%s > _readable.remaining()) {%n", lengthVar);
+ buffer.incrementIndent();
+ buffer.printf("throw new RuntimeException(\"Tried to allocate a collection of size \" + %s + \", but " +
+ "there are only \" + _readable.remaining() + \" bytes remaining.\");%n", lengthVar);
+ buffer.decrementIndent();
+ buffer.printf("}%n");
if (isStructArrayWithKeys) {
headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS);
buffer.printf("%s newCollection = new %s(%s);%n",
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java b/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java
index ff415aa72ad..efb1d69a34f 100644
--- a/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java
@@ -16,6 +16,7 @@
*/
package org.apache.kafka.raft.internals;
+import java.io.DataInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
@@ -25,14 +26,16 @@ import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Optional;
-import org.apache.kafka.common.protocol.DataInputStreamReadable;
-import org.apache.kafka.common.protocol.Readable;
+
+import org.apache.kafka.common.protocol.ByteBufferAccessor;
import org.apache.kafka.common.record.DefaultRecordBatch;
import org.apache.kafka.common.record.FileRecords;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.MutableRecordBatch;
import org.apache.kafka.common.record.Records;
import org.apache.kafka.common.utils.BufferSupplier;
+import org.apache.kafka.common.utils.ByteUtils;
+import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.raft.Batch;
import org.apache.kafka.server.common.serialization.RecordSerde;
@@ -53,6 +56,13 @@ public final class RecordsIterator<T> implements Iterator<Batch<T>>, AutoCloseable
private int bytesRead = 0;
private boolean isClosed = false;
+ /**
+ * This class provides an iterator over records retrieved via the raft client or from a snapshot.
+ * @param records the records
+ * @param serde the serde to deserialize records
+ * @param bufferSupplier the buffer supplier used to allocate buffers when reading records; it must return heap-allocated ByteBuffers
+ * @param batchSize the maximum batch size
+ */
public RecordsIterator(
Records records,
RecordSerde<T> serde,
@@ -99,7 +109,7 @@ public final class RecordsIterator<T> implements Iterator<Batch<T>>, AutoCloseable
private void ensureOpen() {
if (isClosed) {
- throw new IllegalStateException("Serde record batch itererator was closed");
+ throw new IllegalStateException("Serde record batch iterator was closed");
}
}
@@ -205,11 +215,14 @@ public final class RecordsIterator<T> implements Iterator<Batch<T>>, AutoCloseable
}
List<T> records = new ArrayList<>(numRecords);
- try (DataInputStreamReadable input = new DataInputStreamReadable(batch.recordInputStream(bufferSupplier))) {
+ DataInputStream input = new DataInputStream(batch.recordInputStream(bufferSupplier));
+ try {
for (int i = 0; i < numRecords; i++) {
- T record = readRecord(input);
+ T record = readRecord(input, batch.sizeInBytes());
records.add(record);
}
+ } finally {
+ Utils.closeQuietly(input, "DataInputStream");
}
result = Batch.data(
@@ -224,39 +237,74 @@ public final class RecordsIterator<T> implements Iterator<Batch<T>>, AutoCloseable
return result;
}
- private T readRecord(Readable input) {
+ private T readRecord(DataInputStream stream, int totalBatchSize) {
// Read size of body in bytes
- input.readVarint();
-
- // Read unused attributes
- input.readByte();
-
- long timestampDelta = input.readVarlong();
- if (timestampDelta != 0) {
- throw new IllegalArgumentException();
+ int size;
+ try {
+ size = ByteUtils.readVarint(stream);
+ } catch (IOException e) {
+ throw new UncheckedIOException("Unable to read record size", e);
}
-
- // Read offset delta
- input.readVarint();
-
- int keySize = input.readVarint();
- if (keySize != -1) {
- throw new IllegalArgumentException("Unexpected key size " + keySize);
+ if (size <= 0) {
+ throw new RuntimeException("Invalid non-positive frame size: " + size);
}
-
- int valueSize = input.readVarint();
- if (valueSize < 0) {
- throw new IllegalArgumentException();
+ if (size > totalBatchSize) {
+ throw new RuntimeException("Specified frame size, " + size + ", is larger than the entire size of the " +
+ "batch, which is " + totalBatchSize);
}
+ ByteBuffer buf = bufferSupplier.get(size);
- // Read the metadata record body from the file input reader
- T record = serde.read(input, valueSize);
+ // The last byte of the buffer is reserved for a varint set to the number of record headers, which
+ // must be 0. Therefore, we set the ByteBuffer limit to size - 1.
+ buf.limit(size - 1);
- int numHeaders = input.readVarint();
- if (numHeaders != 0) {
- throw new IllegalArgumentException();
+ try {
+ int bytesRead = stream.read(buf.array(), 0, size);
+ if (bytesRead != size) {
+ throw new RuntimeException("Unable to read " + size + " bytes, only read " + bytesRead);
+ }
+ } catch (IOException e) {
+ throw new UncheckedIOException("Failed to read record bytes", e);
}
+ try {
+ ByteBufferAccessor input = new ByteBufferAccessor(buf);
- return record;
+ // Read unused attributes
+ input.readByte();
+
+ long timestampDelta = input.readVarlong();
+ if (timestampDelta != 0) {
+ throw new IllegalArgumentException("Got timestamp delta of " + timestampDelta + ", but this is invalid because it " +
+ "is not 0 as expected.");
+ }
+
+ // Read offset delta
+ input.readVarint();
+
+ int keySize = input.readVarint();
+ if (keySize != -1) {
+ throw new IllegalArgumentException("Got key size of " + keySize + ", but this is invalid because it " +
+ "is not -1 as expected.");
+ }
+
+ int valueSize = input.readVarint();
+ if (valueSize < 1) {
+ throw new IllegalArgumentException("Got payload size of " + valueSize + ", but this is invalid because " +
+ "it is less than 1.");
+ }
+
+ // Read the metadata record body from the file input reader
+ T record = serde.read(input, valueSize);
+
+ // Read the number of headers. Currently, this must be a single byte set to 0.
+ int numHeaders = buf.array()[size - 1];
+ if (numHeaders != 0) {
+ throw new IllegalArgumentException("Got numHeaders of " + numHeaders + ", but this is invalid because " +
+ "it is not 0 as expected.");
+ }
+ return record;
+ } finally {
+ bufferSupplier.release(buf);
+ }
}
}
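Taken together, the rewritten readRecord accepts exactly one frame shape; for reference, the layout it validates (matching what the writers in TestRaftServer and StringSerde produce):

    // [varint]  size of everything after this field (must be > 0 and <= batch size)
    // [1 byte]  attributes (currently unused)
    // [varlong] timestampDelta (must be 0)
    // [varint]  offsetDelta
    // [varint]  key size (must be -1, i.e. null key)
    // [varint]  value size (must be >= 1)
    // [bytes]   value, handed to serde.read(...)
    // [1 byte]  header count, a varint that must be 0 (the reserved final byte of the frame)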
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java b/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java
index c2a011a687d..15475be319f 100644
--- a/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java
@@ -40,8 +40,7 @@ public class StringSerde implements RecordSerde<String> {
@Override
public String read(Readable input, int size) {
- byte[] data = new byte[size];
- input.readArray(data);
+ byte[] data = input.readArray(size);
return Utils.utf8(data);
}