mirror of https://github.com/apache/kafka.git
KAFKA-3088; Make client-id a nullable string and fix handling of invalid requests."
…ent ID - Adds NULLABLE_STRING Type to the protocol - Changes client_id in the REQUEST_HEADER to NULLABLE_STRING with a default of "" - Fixes server handling of invalid ApiKey request and other invalid requests Author: Grant Henke <granthenke@gmail.com> Reviewers: Ismael Juma <ismael@juma.me.uk>, Joel Koshy <jjkoshy.w@gmail.com> Closes #866 from granthenke/null-clientid
This commit is contained in:
parent
85599bc3e8
commit
d7fc7cf615
|
@ -26,6 +26,7 @@ import static org.apache.kafka.common.protocol.types.Type.INT32;
|
||||||
import static org.apache.kafka.common.protocol.types.Type.INT64;
|
import static org.apache.kafka.common.protocol.types.Type.INT64;
|
||||||
import static org.apache.kafka.common.protocol.types.Type.INT8;
|
import static org.apache.kafka.common.protocol.types.Type.INT8;
|
||||||
import static org.apache.kafka.common.protocol.types.Type.STRING;
|
import static org.apache.kafka.common.protocol.types.Type.STRING;
|
||||||
|
import static org.apache.kafka.common.protocol.types.Type.NULLABLE_STRING;
|
||||||
|
|
||||||
public class Protocol {
|
public class Protocol {
|
||||||
|
|
||||||
|
@ -35,8 +36,9 @@ public class Protocol {
|
||||||
INT32,
|
INT32,
|
||||||
"A user-supplied integer value that will be passed back with the response"),
|
"A user-supplied integer value that will be passed back with the response"),
|
||||||
new Field("client_id",
|
new Field("client_id",
|
||||||
STRING,
|
NULLABLE_STRING,
|
||||||
"A user specified identifier for the client making the request."));
|
"A user specified identifier for the client making the request.",
|
||||||
|
""));
|
||||||
|
|
||||||
public static final Schema RESPONSE_HEADER = new Schema(new Field("correlation_id",
|
public static final Schema RESPONSE_HEADER = new Schema(new Field("correlation_id",
|
||||||
INT32,
|
INT32,
|
||||||
|
|
|
@ -216,6 +216,62 @@ public abstract class Type {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
public static final Type NULLABLE_STRING = new Type() {
|
||||||
|
@Override
|
||||||
|
public boolean isNullable() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(ByteBuffer buffer, Object o) {
|
||||||
|
if (o == null) {
|
||||||
|
buffer.putShort((short) -1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
byte[] bytes = Utils.utf8((String) o);
|
||||||
|
if (bytes.length > Short.MAX_VALUE)
|
||||||
|
throw new SchemaException("String is longer than the maximum string length.");
|
||||||
|
buffer.putShort((short) bytes.length);
|
||||||
|
buffer.put(bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object read(ByteBuffer buffer) {
|
||||||
|
int length = buffer.getShort();
|
||||||
|
if (length < 0)
|
||||||
|
return null;
|
||||||
|
|
||||||
|
byte[] bytes = new byte[length];
|
||||||
|
buffer.get(bytes);
|
||||||
|
return Utils.utf8(bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int sizeOf(Object o) {
|
||||||
|
if (o == null)
|
||||||
|
return 2;
|
||||||
|
|
||||||
|
return 2 + Utils.utf8Length((String) o);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "NULLABLE_STRING";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String validate(Object item) {
|
||||||
|
if (item == null)
|
||||||
|
return null;
|
||||||
|
|
||||||
|
if (item instanceof String)
|
||||||
|
return (String) item;
|
||||||
|
else
|
||||||
|
throw new SchemaException(item + " is not a String.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
public static final Type BYTES = new Type() {
|
public static final Type BYTES = new Type() {
|
||||||
@Override
|
@Override
|
||||||
public void write(ByteBuffer buffer, Object o) {
|
public void write(ByteBuffer buffer, Object o) {
|
||||||
|
|
|
@ -38,6 +38,7 @@ public class ProtocolSerializationTest {
|
||||||
new Field("int32", Type.INT32),
|
new Field("int32", Type.INT32),
|
||||||
new Field("int64", Type.INT64),
|
new Field("int64", Type.INT64),
|
||||||
new Field("string", Type.STRING),
|
new Field("string", Type.STRING),
|
||||||
|
new Field("nullable_string", Type.NULLABLE_STRING),
|
||||||
new Field("bytes", Type.BYTES),
|
new Field("bytes", Type.BYTES),
|
||||||
new Field("nullable_bytes", Type.NULLABLE_BYTES),
|
new Field("nullable_bytes", Type.NULLABLE_BYTES),
|
||||||
new Field("array", new ArrayOf(Type.INT32)),
|
new Field("array", new ArrayOf(Type.INT32)),
|
||||||
|
@ -47,6 +48,7 @@ public class ProtocolSerializationTest {
|
||||||
.set("int32", 1)
|
.set("int32", 1)
|
||||||
.set("int64", 1L)
|
.set("int64", 1L)
|
||||||
.set("string", "1")
|
.set("string", "1")
|
||||||
|
.set("nullable_string", null)
|
||||||
.set("bytes", ByteBuffer.wrap("1".getBytes()))
|
.set("bytes", ByteBuffer.wrap("1".getBytes()))
|
||||||
.set("nullable_bytes", null)
|
.set("nullable_bytes", null)
|
||||||
.set("array", new Object[] {1});
|
.set("array", new Object[] {1});
|
||||||
|
@ -62,6 +64,9 @@ public class ProtocolSerializationTest {
|
||||||
check(Type.STRING, "");
|
check(Type.STRING, "");
|
||||||
check(Type.STRING, "hello");
|
check(Type.STRING, "hello");
|
||||||
check(Type.STRING, "A\u00ea\u00f1\u00fcC");
|
check(Type.STRING, "A\u00ea\u00f1\u00fcC");
|
||||||
|
check(Type.NULLABLE_STRING, null);
|
||||||
|
check(Type.NULLABLE_STRING, "");
|
||||||
|
check(Type.NULLABLE_STRING, "hello");
|
||||||
check(Type.BYTES, ByteBuffer.allocate(0));
|
check(Type.BYTES, ByteBuffer.allocate(0));
|
||||||
check(Type.BYTES, ByteBuffer.wrap("abcd".getBytes()));
|
check(Type.BYTES, ByteBuffer.wrap("abcd".getBytes()));
|
||||||
check(Type.NULLABLE_BYTES, null);
|
check(Type.NULLABLE_BYTES, null);
|
||||||
|
@ -99,11 +104,15 @@ public class ProtocolSerializationTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNullableDefault() {
|
public void testNullableDefault() {
|
||||||
|
checkNullableDefault(Type.NULLABLE_BYTES, ByteBuffer.allocate(0));
|
||||||
|
checkNullableDefault(Type.NULLABLE_STRING, "default");
|
||||||
|
}
|
||||||
|
|
||||||
|
private void checkNullableDefault(Type type, Object defaultValue) {
|
||||||
// Should use default even if the field allows null values
|
// Should use default even if the field allows null values
|
||||||
ByteBuffer empty = ByteBuffer.allocate(0);
|
Schema schema = new Schema(new Field("field", type, "doc", defaultValue));
|
||||||
Schema schema = new Schema(new Field("field", Type.NULLABLE_BYTES, "doc", empty));
|
|
||||||
Struct struct = new Struct(schema);
|
Struct struct = new Struct(schema);
|
||||||
assertEquals("Should get the default value", empty, struct.get("field"));
|
assertEquals("Should get the default value", defaultValue, struct.get("field"));
|
||||||
struct.validate(); // should be valid even with missing value
|
struct.validate(); // should be valid even with missing value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -159,6 +159,18 @@ public class RequestResponseTest {
|
||||||
assertEquals(response.partitionsRemaining(), deserialized.partitionsRemaining());
|
assertEquals(response.partitionsRemaining(), deserialized.partitionsRemaining());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRequestHeaderWithNullClientId() {
|
||||||
|
RequestHeader header = new RequestHeader((short) 10, (short) 1, null, 10);
|
||||||
|
ByteBuffer buffer = ByteBuffer.allocate(header.sizeOf());
|
||||||
|
header.writeTo(buffer);
|
||||||
|
buffer.rewind();
|
||||||
|
RequestHeader deserialized = RequestHeader.parse(buffer);
|
||||||
|
assertEquals(header.apiKey(), deserialized.apiKey());
|
||||||
|
assertEquals(header.apiVersion(), deserialized.apiVersion());
|
||||||
|
assertEquals(header.correlationId(), deserialized.correlationId());
|
||||||
|
assertEquals("", deserialized.clientId()); // null is defaulted to ""
|
||||||
|
}
|
||||||
|
|
||||||
private AbstractRequestResponse createRequestHeader() {
|
private AbstractRequestResponse createRequestHeader() {
|
||||||
return new RequestHeader((short) 10, (short) 1, "", 10);
|
return new RequestHeader((short) 10, (short) 1, "", 10);
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
|
|
||||||
package kafka.network
|
package kafka.network
|
||||||
|
|
||||||
class InvalidRequestException(val message: String) extends RuntimeException(message) {
|
class InvalidRequestException(val message: String, cause: Throwable) extends RuntimeException(message, cause) {
|
||||||
|
|
||||||
def this() = this("")
|
|
||||||
|
|
||||||
|
def this() = this("", null)
|
||||||
|
def this(message: String) = this(message, null)
|
||||||
}
|
}
|
||||||
|
|
|
@ -84,7 +84,12 @@ object RequestChannel extends Logging {
|
||||||
null
|
null
|
||||||
val body: AbstractRequest =
|
val body: AbstractRequest =
|
||||||
if (requestObj == null)
|
if (requestObj == null)
|
||||||
|
try {
|
||||||
AbstractRequest.getRequest(header.apiKey, header.apiVersion, buffer)
|
AbstractRequest.getRequest(header.apiKey, header.apiVersion, buffer)
|
||||||
|
} catch {
|
||||||
|
case ex: Throwable =>
|
||||||
|
throw new InvalidRequestException(s"Error getting request for apiKey: ${header.apiKey} and apiVersion: ${header.apiVersion}", ex)
|
||||||
|
}
|
||||||
else
|
else
|
||||||
null
|
null
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,171 @@
|
||||||
|
/**
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package kafka.server
|
||||||
|
|
||||||
|
import java.io.{DataInputStream, DataOutputStream}
|
||||||
|
import java.net.Socket
|
||||||
|
import java.nio.ByteBuffer
|
||||||
|
|
||||||
|
import kafka.integration.KafkaServerTestHarness
|
||||||
|
|
||||||
|
import kafka.network.SocketServer
|
||||||
|
import kafka.utils._
|
||||||
|
import org.apache.kafka.common.TopicPartition
|
||||||
|
import org.apache.kafka.common.protocol.types.Type
|
||||||
|
import org.apache.kafka.common.protocol.{ApiKeys, SecurityProtocol}
|
||||||
|
import org.apache.kafka.common.requests.{ProduceResponse, ResponseHeader, ProduceRequest}
|
||||||
|
import org.junit.Assert._
|
||||||
|
import org.junit.Test
|
||||||
|
|
||||||
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
|
class EdgeCaseRequestTest extends KafkaServerTestHarness {
|
||||||
|
|
||||||
|
def generateConfigs() = {
|
||||||
|
val props = TestUtils.createBrokerConfig(1, zkConnect)
|
||||||
|
props.setProperty(KafkaConfig.AutoCreateTopicsEnableProp, "false")
|
||||||
|
List(KafkaConfig.fromProps(props))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def socketServer = servers.head.socketServer
|
||||||
|
|
||||||
|
private def connect(s: SocketServer = socketServer, protocol: SecurityProtocol = SecurityProtocol.PLAINTEXT): Socket = {
|
||||||
|
new Socket("localhost", s.boundPort(protocol))
|
||||||
|
}
|
||||||
|
|
||||||
|
private def sendRequest(socket: Socket, request: Array[Byte], id: Option[Short] = None) {
|
||||||
|
val outgoing = new DataOutputStream(socket.getOutputStream)
|
||||||
|
id match {
|
||||||
|
case Some(id) =>
|
||||||
|
outgoing.writeInt(request.length + 2)
|
||||||
|
outgoing.writeShort(id)
|
||||||
|
case None =>
|
||||||
|
outgoing.writeInt(request.length)
|
||||||
|
}
|
||||||
|
outgoing.write(request)
|
||||||
|
outgoing.flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
private def receiveResponse(socket: Socket): Array[Byte] = {
|
||||||
|
val incoming = new DataInputStream(socket.getInputStream)
|
||||||
|
val len = incoming.readInt()
|
||||||
|
val response = new Array[Byte](len)
|
||||||
|
incoming.readFully(response)
|
||||||
|
response
|
||||||
|
}
|
||||||
|
|
||||||
|
private def requestAndReceive(request: Array[Byte], id: Option[Short] = None): Array[Byte] = {
|
||||||
|
val plainSocket = connect()
|
||||||
|
try {
|
||||||
|
sendRequest(plainSocket, request, id)
|
||||||
|
receiveResponse(plainSocket)
|
||||||
|
} finally {
|
||||||
|
plainSocket.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom header serialization so that protocol assumptions are not forced
|
||||||
|
private def requestHeaderBytes(apiKey: Short, apiVersion: Short, clientId: String = "", correlationId: Int = -1): Array[Byte] = {
|
||||||
|
val size = {
|
||||||
|
2 /* apiKey */ +
|
||||||
|
2 /* version id */ +
|
||||||
|
4 /* correlation id */ +
|
||||||
|
Type.NULLABLE_STRING.sizeOf(clientId) /* client id */
|
||||||
|
}
|
||||||
|
|
||||||
|
val buffer = ByteBuffer.allocate(size)
|
||||||
|
buffer.putShort(apiKey)
|
||||||
|
buffer.putShort(apiVersion)
|
||||||
|
buffer.putInt(correlationId)
|
||||||
|
Type.NULLABLE_STRING.write(buffer, clientId)
|
||||||
|
buffer.array()
|
||||||
|
}
|
||||||
|
|
||||||
|
private def verifyDisconnect(request: Array[Byte]) {
|
||||||
|
val plainSocket = connect()
|
||||||
|
try {
|
||||||
|
sendRequest(plainSocket, requestHeaderBytes(-1, 0))
|
||||||
|
assertEquals("The server should disconnect", -1, plainSocket.getInputStream.read())
|
||||||
|
} finally {
|
||||||
|
plainSocket.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
def testProduceRequestWithNullClientId() {
|
||||||
|
val topic = "topic"
|
||||||
|
val topicPartition = new TopicPartition(topic, 0)
|
||||||
|
val correlationId = -1
|
||||||
|
TestUtils.createTopic(zkUtils, topic, numPartitions = 1, replicationFactor = 1, servers = servers)
|
||||||
|
|
||||||
|
val serializedBytes = {
|
||||||
|
val headerBytes = requestHeaderBytes(ApiKeys.PRODUCE.id, 1, null, correlationId)
|
||||||
|
val messageBytes = "message".getBytes
|
||||||
|
val request = new ProduceRequest(1, 10000, Map(topicPartition -> ByteBuffer.wrap(messageBytes)).asJava)
|
||||||
|
val byteBuffer = ByteBuffer.allocate(headerBytes.length + request.sizeOf)
|
||||||
|
byteBuffer.put(headerBytes)
|
||||||
|
request.writeTo(byteBuffer)
|
||||||
|
byteBuffer.array()
|
||||||
|
}
|
||||||
|
|
||||||
|
val response = requestAndReceive(serializedBytes)
|
||||||
|
|
||||||
|
val responseBuffer = ByteBuffer.wrap(response)
|
||||||
|
val responseHeader = ResponseHeader.parse(responseBuffer)
|
||||||
|
val produceResponse = ProduceResponse.parse(responseBuffer)
|
||||||
|
|
||||||
|
assertEquals("The response should parse completely", 0, responseBuffer.remaining())
|
||||||
|
assertEquals("The correlationId should match request", correlationId, responseHeader.correlationId())
|
||||||
|
assertEquals("One partition response should be returned", 1, produceResponse.responses().size())
|
||||||
|
|
||||||
|
val partitionResponse = produceResponse.responses().get(topicPartition)
|
||||||
|
assertNotNull(partitionResponse)
|
||||||
|
assertEquals("There should be no error", 0, partitionResponse.errorCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
def testHeaderOnlyRequest() {
|
||||||
|
verifyDisconnect(requestHeaderBytes(ApiKeys.PRODUCE.id, 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
def testInvalidApiKeyRequest() {
|
||||||
|
verifyDisconnect(requestHeaderBytes(-1, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
def testInvalidApiVersionRequest() {
|
||||||
|
verifyDisconnect(requestHeaderBytes(ApiKeys.PRODUCE.id, -1))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
def testMalformedHeaderRequest() {
|
||||||
|
val serializedBytes = {
|
||||||
|
// Only send apiKey and apiVersion
|
||||||
|
val buffer = ByteBuffer.allocate(
|
||||||
|
2 /* apiKey */ +
|
||||||
|
2 /* apiVersion */
|
||||||
|
)
|
||||||
|
buffer.putShort(ApiKeys.PRODUCE.id)
|
||||||
|
buffer.putShort(1)
|
||||||
|
buffer.array()
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyDisconnect(serializedBytes)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue