KAFKA-10388; Fix struct conversion logic for tagged structures (#9166)

The message generator was missing conversion logic for tagged structures. This led to casting errors when either `fromStruct` or `toStruct` were invoked. This patch also adds missing null checks in the serialization of tagged byte arrays, which was found from improved test coverage.

Reviewers: Colin P. McCabe <cmccabe@apache.org>
This commit is contained in:
Jason Gustafson 2020-08-12 08:29:59 -07:00 committed by GitHub
parent 7915d5e5f8
commit 89e12f3c6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 12 deletions

View File

@ -19,6 +19,7 @@ package org.apache.kafka.common.message;
import org.apache.kafka.common.errors.UnsupportedVersionException;
import org.apache.kafka.common.protocol.ByteBufferAccessor;
import org.apache.kafka.common.protocol.ObjectSerializationCache;
import org.apache.kafka.common.protocol.types.Schema;
import org.apache.kafka.common.protocol.types.Struct;
import org.apache.kafka.common.utils.ByteUtils;
import org.junit.Test;
@ -321,6 +322,38 @@ public class SimpleExampleMessageTest {
message -> assertEquals("abc", message.myString()), (short) 2);
}
private ByteBuffer serialize(SimpleExampleMessageData message, short version) {
ObjectSerializationCache cache = new ObjectSerializationCache();
int size = message.size(cache, version);
ByteBuffer buf = ByteBuffer.allocate(size);
message.write(new ByteBufferAccessor(buf), cache, version);
buf.flip();
assertEquals(size, buf.remaining());
return buf;
}
private SimpleExampleMessageData deserialize(ByteBuffer buf, short version) {
SimpleExampleMessageData message = new SimpleExampleMessageData();
message.read(new ByteBufferAccessor(buf.duplicate()), version);
return message;
}
private ByteBuffer serializeThroughStruct(SimpleExampleMessageData message, short version) {
Struct struct = message.toStruct(version);
int size = struct.sizeOf();
ByteBuffer buf = ByteBuffer.allocate(size);
struct.writeTo(buf);
buf.flip();
assertEquals(size, buf.remaining());
return buf;
}
private SimpleExampleMessageData deserializeThroughStruct(ByteBuffer buf, short version) {
Schema schema = SimpleExampleMessageData.SCHEMAS[version];
Struct struct = schema.read(buf);
return new SimpleExampleMessageData(struct, version);
}
private void testRoundTrip(SimpleExampleMessageData message,
Consumer<SimpleExampleMessageData> validator) {
testRoundTrip(message, validator, (short) 1);
@ -330,17 +363,18 @@ public class SimpleExampleMessageTest {
Consumer<SimpleExampleMessageData> validator,
short version) {
validator.accept(message);
ObjectSerializationCache cache = new ObjectSerializationCache();
int size = message.size(cache, version);
ByteBuffer buf = ByteBuffer.allocate(size);
message.write(new ByteBufferAccessor(buf), cache, version);
buf.flip();
assertEquals(size, buf.remaining());
ByteBuffer buf = serialize(message, version);
SimpleExampleMessageData message2 = new SimpleExampleMessageData();
message2.read(new ByteBufferAccessor(buf.duplicate()), version);
SimpleExampleMessageData message2 = deserialize(buf.duplicate(), version);
validator.accept(message2);
assertEquals(message, message2);
assertEquals(message.hashCode(), message2.hashCode());
// Check struct serialization as well
assertEquals(buf, serializeThroughStruct(message, version));
SimpleExampleMessageData messageFromStruct = deserializeThroughStruct(buf.duplicate(), version);
validator.accept(messageFromStruct);
assertEquals(message, messageFromStruct);
assertEquals(message.hashCode(), messageFromStruct.hashCode());
}
}

View File

@ -793,10 +793,25 @@ public final class MessageDataGenerator {
generateArrayFromStruct(field, presentAndTaggedVersions);
} else if (field.type().isBytes()) {
headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS);
headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS);
buffer.printf("this.%s = MessageUtil.byteBufferToArray(" +
"(ByteBuffer) _taggedFields.remove(%d));%n",
buffer.printf("ByteBuffer _byteBuffer = (ByteBuffer) _taggedFields.remove(%d);%n",
field.tag().get());
IsNullConditional.forName("_byteBuffer").
nullableVersions(field.nullableVersions()).
possibleVersions(field.versions()).
ifNull(() -> {
buffer.printf("this.%s = null;%n", field.camelCaseName());
}).
ifShouldNotBeNull(() -> {
headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS);
buffer.printf("this.%s = MessageUtil.byteBufferToArray(_byteBuffer);%n",
field.camelCaseName());
}).
generate(buffer);
} else if (field.type().isStruct()) {
buffer.printf("this.%s = new %s((Struct) _taggedFields.remove(%d), _version);%n",
field.camelCaseName(),
getBoxedJavaType(field.type()),
field.tag().get());
} else {
buffer.printf("this.%s = (%s) _taggedFields.remove(%d);%n",
@ -1731,10 +1746,12 @@ public final class MessageDataGenerator {
(field.type() instanceof FieldType.Int64FieldType) ||
(field.type() instanceof FieldType.UUIDFieldType) ||
(field.type() instanceof FieldType.Float64FieldType) ||
(field.type() instanceof FieldType.StructType) ||
(field.type() instanceof FieldType.StringFieldType)) {
buffer.printf("_taggedFields.put(%d, %s);%n",
field.tag().get(), field.camelCaseName());
} else if (field.type().isStruct()) {
buffer.printf("_taggedFields.put(%d, %s.toStruct(_version));%n",
field.tag().get(), field.camelCaseName());
} else if (field.type().isBytes()) {
headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS);
if (field.taggedVersions().intersect(field.nullableVersions()).empty()) {