From d7fc7cf6154592b7fea494a092e42ad9d45b98a0 Mon Sep 17 00:00:00 2001 From: Grant Henke Date: Fri, 12 Feb 2016 18:49:19 -0800 Subject: [PATCH] KAFKA-3088; Make client-id a nullable string and fix handling of invalid requests." MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ent ID - Adds NULLABLE_STRING Type to the protocol - Changes client_id in the REQUEST_HEADER to NULLABLE_STRING with a default of "" - Fixes server handling of invalid ApiKey request and other invalid requests Author: Grant Henke Reviewers: Ismael Juma , Joel Koshy Closes #866 from granthenke/null-clientid --- .../kafka/common/protocol/Protocol.java | 6 +- .../kafka/common/protocol/types/Type.java | 56 ++++++ .../types/ProtocolSerializationTest.java | 15 +- .../common/requests/RequestResponseTest.java | 12 ++ .../network/InvalidRequestException.scala | 10 +- .../scala/kafka/network/RequestChannel.scala | 7 +- .../kafka/server/EdgeCaseRequestTest.scala | 171 ++++++++++++++++++ 7 files changed, 266 insertions(+), 11 deletions(-) create mode 100755 core/src/test/scala/unit/kafka/server/EdgeCaseRequestTest.scala diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java b/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java index ff844e7402b..48c64c206f0 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java @@ -26,6 +26,7 @@ import static org.apache.kafka.common.protocol.types.Type.INT32; import static org.apache.kafka.common.protocol.types.Type.INT64; import static org.apache.kafka.common.protocol.types.Type.INT8; import static org.apache.kafka.common.protocol.types.Type.STRING; +import static org.apache.kafka.common.protocol.types.Type.NULLABLE_STRING; public class Protocol { @@ -35,8 +36,9 @@ public class Protocol { INT32, "A user-supplied integer value that will be passed back with the response"), new Field("client_id", - STRING, - "A user specified identifier for the client making the request.")); + NULLABLE_STRING, + "A user specified identifier for the client making the request.", + "")); public static final Schema RESPONSE_HEADER = new Schema(new Field("correlation_id", INT32, diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/Type.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/Type.java index 04833877979..c4bcb1e4920 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/types/Type.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/Type.java @@ -216,6 +216,62 @@ public abstract class Type { } }; + public static final Type NULLABLE_STRING = new Type() { + @Override + public boolean isNullable() { + return true; + } + + @Override + public void write(ByteBuffer buffer, Object o) { + if (o == null) { + buffer.putShort((short) -1); + return; + } + + byte[] bytes = Utils.utf8((String) o); + if (bytes.length > Short.MAX_VALUE) + throw new SchemaException("String is longer than the maximum string length."); + buffer.putShort((short) bytes.length); + buffer.put(bytes); + } + + @Override + public Object read(ByteBuffer buffer) { + int length = buffer.getShort(); + if (length < 0) + return null; + + byte[] bytes = new byte[length]; + buffer.get(bytes); + return Utils.utf8(bytes); + } + + @Override + public int sizeOf(Object o) { + if (o == null) + return 2; + + return 2 + Utils.utf8Length((String) o); + } + + @Override + public String toString() { + return "NULLABLE_STRING"; + } + + @Override + public String validate(Object item) { + if (item == null) + return null; + + if (item instanceof String) + return (String) item; + else + throw new SchemaException(item + " is not a String."); + } + }; + public static final Type BYTES = new Type() { @Override public void write(ByteBuffer buffer, Object o) { diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/types/ProtocolSerializationTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/types/ProtocolSerializationTest.java index 9fe20c145c3..e20aa109325 100644 --- a/clients/src/test/java/org/apache/kafka/common/protocol/types/ProtocolSerializationTest.java +++ b/clients/src/test/java/org/apache/kafka/common/protocol/types/ProtocolSerializationTest.java @@ -38,6 +38,7 @@ public class ProtocolSerializationTest { new Field("int32", Type.INT32), new Field("int64", Type.INT64), new Field("string", Type.STRING), + new Field("nullable_string", Type.NULLABLE_STRING), new Field("bytes", Type.BYTES), new Field("nullable_bytes", Type.NULLABLE_BYTES), new Field("array", new ArrayOf(Type.INT32)), @@ -47,6 +48,7 @@ public class ProtocolSerializationTest { .set("int32", 1) .set("int64", 1L) .set("string", "1") + .set("nullable_string", null) .set("bytes", ByteBuffer.wrap("1".getBytes())) .set("nullable_bytes", null) .set("array", new Object[] {1}); @@ -62,6 +64,9 @@ public class ProtocolSerializationTest { check(Type.STRING, ""); check(Type.STRING, "hello"); check(Type.STRING, "A\u00ea\u00f1\u00fcC"); + check(Type.NULLABLE_STRING, null); + check(Type.NULLABLE_STRING, ""); + check(Type.NULLABLE_STRING, "hello"); check(Type.BYTES, ByteBuffer.allocate(0)); check(Type.BYTES, ByteBuffer.wrap("abcd".getBytes())); check(Type.NULLABLE_BYTES, null); @@ -99,11 +104,15 @@ public class ProtocolSerializationTest { @Test public void testNullableDefault() { + checkNullableDefault(Type.NULLABLE_BYTES, ByteBuffer.allocate(0)); + checkNullableDefault(Type.NULLABLE_STRING, "default"); + } + + private void checkNullableDefault(Type type, Object defaultValue) { // Should use default even if the field allows null values - ByteBuffer empty = ByteBuffer.allocate(0); - Schema schema = new Schema(new Field("field", Type.NULLABLE_BYTES, "doc", empty)); + Schema schema = new Schema(new Field("field", type, "doc", defaultValue)); Struct struct = new Struct(schema); - assertEquals("Should get the default value", empty, struct.get("field")); + assertEquals("Should get the default value", defaultValue, struct.get("field")); struct.validate(); // should be valid even with missing value } diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java index 789cca79f5e..db9c81a012a 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java @@ -159,6 +159,18 @@ public class RequestResponseTest { assertEquals(response.partitionsRemaining(), deserialized.partitionsRemaining()); } + @Test + public void testRequestHeaderWithNullClientId() { + RequestHeader header = new RequestHeader((short) 10, (short) 1, null, 10); + ByteBuffer buffer = ByteBuffer.allocate(header.sizeOf()); + header.writeTo(buffer); + buffer.rewind(); + RequestHeader deserialized = RequestHeader.parse(buffer); + assertEquals(header.apiKey(), deserialized.apiKey()); + assertEquals(header.apiVersion(), deserialized.apiVersion()); + assertEquals(header.correlationId(), deserialized.correlationId()); + assertEquals("", deserialized.clientId()); // null is defaulted to "" + } private AbstractRequestResponse createRequestHeader() { return new RequestHeader((short) 10, (short) 1, "", 10); diff --git a/core/src/main/scala/kafka/network/InvalidRequestException.scala b/core/src/main/scala/kafka/network/InvalidRequestException.scala index 5197913fd5c..47dba6cced4 100644 --- a/core/src/main/scala/kafka/network/InvalidRequestException.scala +++ b/core/src/main/scala/kafka/network/InvalidRequestException.scala @@ -5,7 +5,7 @@ * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software @@ -17,8 +17,8 @@ package kafka.network -class InvalidRequestException(val message: String) extends RuntimeException(message) { - - def this() = this("") - +class InvalidRequestException(val message: String, cause: Throwable) extends RuntimeException(message, cause) { + + def this() = this("", null) + def this(message: String) = this(message, null) } diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala index f0d599ddc85..47f7a34f840 100644 --- a/core/src/main/scala/kafka/network/RequestChannel.scala +++ b/core/src/main/scala/kafka/network/RequestChannel.scala @@ -84,7 +84,12 @@ object RequestChannel extends Logging { null val body: AbstractRequest = if (requestObj == null) - AbstractRequest.getRequest(header.apiKey, header.apiVersion, buffer) + try { + AbstractRequest.getRequest(header.apiKey, header.apiVersion, buffer) + } catch { + case ex: Throwable => + throw new InvalidRequestException(s"Error getting request for apiKey: ${header.apiKey} and apiVersion: ${header.apiVersion}", ex) + } else null diff --git a/core/src/test/scala/unit/kafka/server/EdgeCaseRequestTest.scala b/core/src/test/scala/unit/kafka/server/EdgeCaseRequestTest.scala new file mode 100755 index 00000000000..155eea0f0b7 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/EdgeCaseRequestTest.scala @@ -0,0 +1,171 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket +import java.nio.ByteBuffer + +import kafka.integration.KafkaServerTestHarness + +import kafka.network.SocketServer +import kafka.utils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.types.Type +import org.apache.kafka.common.protocol.{ApiKeys, SecurityProtocol} +import org.apache.kafka.common.requests.{ProduceResponse, ResponseHeader, ProduceRequest} +import org.junit.Assert._ +import org.junit.Test + +import scala.collection.JavaConverters._ + +class EdgeCaseRequestTest extends KafkaServerTestHarness { + + def generateConfigs() = { + val props = TestUtils.createBrokerConfig(1, zkConnect) + props.setProperty(KafkaConfig.AutoCreateTopicsEnableProp, "false") + List(KafkaConfig.fromProps(props)) + } + + private def socketServer = servers.head.socketServer + + private def connect(s: SocketServer = socketServer, protocol: SecurityProtocol = SecurityProtocol.PLAINTEXT): Socket = { + new Socket("localhost", s.boundPort(protocol)) + } + + private def sendRequest(socket: Socket, request: Array[Byte], id: Option[Short] = None) { + val outgoing = new DataOutputStream(socket.getOutputStream) + id match { + case Some(id) => + outgoing.writeInt(request.length + 2) + outgoing.writeShort(id) + case None => + outgoing.writeInt(request.length) + } + outgoing.write(request) + outgoing.flush() + } + + private def receiveResponse(socket: Socket): Array[Byte] = { + val incoming = new DataInputStream(socket.getInputStream) + val len = incoming.readInt() + val response = new Array[Byte](len) + incoming.readFully(response) + response + } + + private def requestAndReceive(request: Array[Byte], id: Option[Short] = None): Array[Byte] = { + val plainSocket = connect() + try { + sendRequest(plainSocket, request, id) + receiveResponse(plainSocket) + } finally { + plainSocket.close() + } + } + + // Custom header serialization so that protocol assumptions are not forced + private def requestHeaderBytes(apiKey: Short, apiVersion: Short, clientId: String = "", correlationId: Int = -1): Array[Byte] = { + val size = { + 2 /* apiKey */ + + 2 /* version id */ + + 4 /* correlation id */ + + Type.NULLABLE_STRING.sizeOf(clientId) /* client id */ + } + + val buffer = ByteBuffer.allocate(size) + buffer.putShort(apiKey) + buffer.putShort(apiVersion) + buffer.putInt(correlationId) + Type.NULLABLE_STRING.write(buffer, clientId) + buffer.array() + } + + private def verifyDisconnect(request: Array[Byte]) { + val plainSocket = connect() + try { + sendRequest(plainSocket, requestHeaderBytes(-1, 0)) + assertEquals("The server should disconnect", -1, plainSocket.getInputStream.read()) + } finally { + plainSocket.close() + } + } + + @Test + def testProduceRequestWithNullClientId() { + val topic = "topic" + val topicPartition = new TopicPartition(topic, 0) + val correlationId = -1 + TestUtils.createTopic(zkUtils, topic, numPartitions = 1, replicationFactor = 1, servers = servers) + + val serializedBytes = { + val headerBytes = requestHeaderBytes(ApiKeys.PRODUCE.id, 1, null, correlationId) + val messageBytes = "message".getBytes + val request = new ProduceRequest(1, 10000, Map(topicPartition -> ByteBuffer.wrap(messageBytes)).asJava) + val byteBuffer = ByteBuffer.allocate(headerBytes.length + request.sizeOf) + byteBuffer.put(headerBytes) + request.writeTo(byteBuffer) + byteBuffer.array() + } + + val response = requestAndReceive(serializedBytes) + + val responseBuffer = ByteBuffer.wrap(response) + val responseHeader = ResponseHeader.parse(responseBuffer) + val produceResponse = ProduceResponse.parse(responseBuffer) + + assertEquals("The response should parse completely", 0, responseBuffer.remaining()) + assertEquals("The correlationId should match request", correlationId, responseHeader.correlationId()) + assertEquals("One partition response should be returned", 1, produceResponse.responses().size()) + + val partitionResponse = produceResponse.responses().get(topicPartition) + assertNotNull(partitionResponse) + assertEquals("There should be no error", 0, partitionResponse.errorCode) + } + + @Test + def testHeaderOnlyRequest() { + verifyDisconnect(requestHeaderBytes(ApiKeys.PRODUCE.id, 1)) + } + + @Test + def testInvalidApiKeyRequest() { + verifyDisconnect(requestHeaderBytes(-1, 0)) + } + + @Test + def testInvalidApiVersionRequest() { + verifyDisconnect(requestHeaderBytes(ApiKeys.PRODUCE.id, -1)) + } + + @Test + def testMalformedHeaderRequest() { + val serializedBytes = { + // Only send apiKey and apiVersion + val buffer = ByteBuffer.allocate( + 2 /* apiKey */ + + 2 /* apiVersion */ + ) + buffer.putShort(ApiKeys.PRODUCE.id) + buffer.putShort(1) + buffer.array() + } + + verifyDisconnect(serializedBytes) + } +}