From 89e12f3c6b4f37c405455d5cb0fca60f2be3ff92 Mon Sep 17 00:00:00 2001 From: Jason Gustafson Date: Wed, 12 Aug 2020 08:29:59 -0700 Subject: [PATCH] KAFKA-10388; Fix struct conversion logic for tagged structures (#9166) The message generator was missing conversion logic for tagged structures. This led to casting errors when either `fromStruct` or `toStruct` were invoked. This patch also adds missing null checks in the serialization of tagged byte arrays, which was found from improved test coverage. Reviewers: Colin P. McCabe --- .../message/SimpleExampleMessageTest.java | 50 ++++++++++++++++--- .../kafka/message/MessageDataGenerator.java | 25 ++++++++-- 2 files changed, 63 insertions(+), 12 deletions(-) diff --git a/clients/src/test/java/org/apache/kafka/common/message/SimpleExampleMessageTest.java b/clients/src/test/java/org/apache/kafka/common/message/SimpleExampleMessageTest.java index 6e229ae1e1f..e2c79358026 100644 --- a/clients/src/test/java/org/apache/kafka/common/message/SimpleExampleMessageTest.java +++ b/clients/src/test/java/org/apache/kafka/common/message/SimpleExampleMessageTest.java @@ -19,6 +19,7 @@ package org.apache.kafka.common.message; import org.apache.kafka.common.errors.UnsupportedVersionException; import org.apache.kafka.common.protocol.ByteBufferAccessor; import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.types.Schema; import org.apache.kafka.common.protocol.types.Struct; import org.apache.kafka.common.utils.ByteUtils; import org.junit.Test; @@ -321,6 +322,38 @@ public class SimpleExampleMessageTest { message -> assertEquals("abc", message.myString()), (short) 2); } + private ByteBuffer serialize(SimpleExampleMessageData message, short version) { + ObjectSerializationCache cache = new ObjectSerializationCache(); + int size = message.size(cache, version); + ByteBuffer buf = ByteBuffer.allocate(size); + message.write(new ByteBufferAccessor(buf), cache, version); + buf.flip(); + assertEquals(size, buf.remaining()); + return buf; + } + + private SimpleExampleMessageData deserialize(ByteBuffer buf, short version) { + SimpleExampleMessageData message = new SimpleExampleMessageData(); + message.read(new ByteBufferAccessor(buf.duplicate()), version); + return message; + } + + private ByteBuffer serializeThroughStruct(SimpleExampleMessageData message, short version) { + Struct struct = message.toStruct(version); + int size = struct.sizeOf(); + ByteBuffer buf = ByteBuffer.allocate(size); + struct.writeTo(buf); + buf.flip(); + assertEquals(size, buf.remaining()); + return buf; + } + + private SimpleExampleMessageData deserializeThroughStruct(ByteBuffer buf, short version) { + Schema schema = SimpleExampleMessageData.SCHEMAS[version]; + Struct struct = schema.read(buf); + return new SimpleExampleMessageData(struct, version); + } + private void testRoundTrip(SimpleExampleMessageData message, Consumer validator) { testRoundTrip(message, validator, (short) 1); @@ -330,17 +363,18 @@ public class SimpleExampleMessageTest { Consumer validator, short version) { validator.accept(message); - ObjectSerializationCache cache = new ObjectSerializationCache(); - int size = message.size(cache, version); - ByteBuffer buf = ByteBuffer.allocate(size); - message.write(new ByteBufferAccessor(buf), cache, version); - buf.flip(); - assertEquals(size, buf.remaining()); + ByteBuffer buf = serialize(message, version); - SimpleExampleMessageData message2 = new SimpleExampleMessageData(); - message2.read(new ByteBufferAccessor(buf.duplicate()), version); + SimpleExampleMessageData message2 = deserialize(buf.duplicate(), version); validator.accept(message2); assertEquals(message, message2); assertEquals(message.hashCode(), message2.hashCode()); + + // Check struct serialization as well + assertEquals(buf, serializeThroughStruct(message, version)); + SimpleExampleMessageData messageFromStruct = deserializeThroughStruct(buf.duplicate(), version); + validator.accept(messageFromStruct); + assertEquals(message, messageFromStruct); + assertEquals(message.hashCode(), messageFromStruct.hashCode()); } } diff --git a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java index 137e86b5295..e524b10796b 100644 --- a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java +++ b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java @@ -793,10 +793,25 @@ public final class MessageDataGenerator { generateArrayFromStruct(field, presentAndTaggedVersions); } else if (field.type().isBytes()) { headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS); - headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); - buffer.printf("this.%s = MessageUtil.byteBufferToArray(" + - "(ByteBuffer) _taggedFields.remove(%d));%n", + buffer.printf("ByteBuffer _byteBuffer = (ByteBuffer) _taggedFields.remove(%d);%n", + field.tag().get()); + + IsNullConditional.forName("_byteBuffer"). + nullableVersions(field.nullableVersions()). + possibleVersions(field.versions()). + ifNull(() -> { + buffer.printf("this.%s = null;%n", field.camelCaseName()); + }). + ifShouldNotBeNull(() -> { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + buffer.printf("this.%s = MessageUtil.byteBufferToArray(_byteBuffer);%n", + field.camelCaseName()); + }). + generate(buffer); + } else if (field.type().isStruct()) { + buffer.printf("this.%s = new %s((Struct) _taggedFields.remove(%d), _version);%n", field.camelCaseName(), + getBoxedJavaType(field.type()), field.tag().get()); } else { buffer.printf("this.%s = (%s) _taggedFields.remove(%d);%n", @@ -1731,10 +1746,12 @@ public final class MessageDataGenerator { (field.type() instanceof FieldType.Int64FieldType) || (field.type() instanceof FieldType.UUIDFieldType) || (field.type() instanceof FieldType.Float64FieldType) || - (field.type() instanceof FieldType.StructType) || (field.type() instanceof FieldType.StringFieldType)) { buffer.printf("_taggedFields.put(%d, %s);%n", field.tag().get(), field.camelCaseName()); + } else if (field.type().isStruct()) { + buffer.printf("_taggedFields.put(%d, %s.toStruct(_version));%n", + field.tag().get(), field.camelCaseName()); } else if (field.type().isBytes()) { headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS); if (field.taggedVersions().intersect(field.nullableVersions()).empty()) {