mirror of https://github.com/apache/kafka.git
				
				
				
			KAFKA-10388; Fix struct conversion logic for tagged structures (#9166)
The message generator was missing conversion logic for tagged structures. This led to casting errors when either `fromStruct` or `toStruct` were invoked. This patch also adds missing null checks in the serialization of tagged byte arrays, which was found from improved test coverage. Reviewers: Colin P. McCabe <cmccabe@apache.org>
This commit is contained in:
		
							parent
							
								
									7915d5e5f8
								
							
						
					
					
						commit
						89e12f3c6b
					
				|  | @ -19,6 +19,7 @@ package org.apache.kafka.common.message; | |||
| import org.apache.kafka.common.errors.UnsupportedVersionException; | ||||
| import org.apache.kafka.common.protocol.ByteBufferAccessor; | ||||
| import org.apache.kafka.common.protocol.ObjectSerializationCache; | ||||
| import org.apache.kafka.common.protocol.types.Schema; | ||||
| import org.apache.kafka.common.protocol.types.Struct; | ||||
| import org.apache.kafka.common.utils.ByteUtils; | ||||
| import org.junit.Test; | ||||
|  | @ -321,6 +322,38 @@ public class SimpleExampleMessageTest { | |||
|             message -> assertEquals("abc", message.myString()), (short) 2); | ||||
|     } | ||||
| 
 | ||||
|     private ByteBuffer serialize(SimpleExampleMessageData message, short version) { | ||||
|         ObjectSerializationCache cache = new ObjectSerializationCache(); | ||||
|         int size = message.size(cache, version); | ||||
|         ByteBuffer buf = ByteBuffer.allocate(size); | ||||
|         message.write(new ByteBufferAccessor(buf), cache, version); | ||||
|         buf.flip(); | ||||
|         assertEquals(size, buf.remaining()); | ||||
|         return buf; | ||||
|     } | ||||
| 
 | ||||
|     private SimpleExampleMessageData deserialize(ByteBuffer buf, short version) { | ||||
|         SimpleExampleMessageData message = new SimpleExampleMessageData(); | ||||
|         message.read(new ByteBufferAccessor(buf.duplicate()), version); | ||||
|         return message; | ||||
|     } | ||||
| 
 | ||||
|     private ByteBuffer serializeThroughStruct(SimpleExampleMessageData message, short version) { | ||||
|         Struct struct = message.toStruct(version); | ||||
|         int size = struct.sizeOf(); | ||||
|         ByteBuffer buf = ByteBuffer.allocate(size); | ||||
|         struct.writeTo(buf); | ||||
|         buf.flip(); | ||||
|         assertEquals(size, buf.remaining()); | ||||
|         return buf; | ||||
|     } | ||||
| 
 | ||||
|     private SimpleExampleMessageData deserializeThroughStruct(ByteBuffer buf, short version) { | ||||
|         Schema schema = SimpleExampleMessageData.SCHEMAS[version]; | ||||
|         Struct struct = schema.read(buf); | ||||
|         return new SimpleExampleMessageData(struct, version); | ||||
|     } | ||||
| 
 | ||||
|     private void testRoundTrip(SimpleExampleMessageData message, | ||||
|                                Consumer<SimpleExampleMessageData> validator) { | ||||
|         testRoundTrip(message, validator, (short) 1); | ||||
|  | @ -330,17 +363,18 @@ public class SimpleExampleMessageTest { | |||
|                                Consumer<SimpleExampleMessageData> validator, | ||||
|                                short version) { | ||||
|         validator.accept(message); | ||||
|         ObjectSerializationCache cache = new ObjectSerializationCache(); | ||||
|         int size = message.size(cache, version); | ||||
|         ByteBuffer buf = ByteBuffer.allocate(size); | ||||
|         message.write(new ByteBufferAccessor(buf), cache, version); | ||||
|         buf.flip(); | ||||
|         assertEquals(size, buf.remaining()); | ||||
|         ByteBuffer buf = serialize(message, version); | ||||
| 
 | ||||
|         SimpleExampleMessageData message2 = new SimpleExampleMessageData(); | ||||
|         message2.read(new ByteBufferAccessor(buf.duplicate()), version); | ||||
|         SimpleExampleMessageData message2 = deserialize(buf.duplicate(), version); | ||||
|         validator.accept(message2); | ||||
|         assertEquals(message, message2); | ||||
|         assertEquals(message.hashCode(), message2.hashCode()); | ||||
| 
 | ||||
|         // Check struct serialization as well | ||||
|         assertEquals(buf, serializeThroughStruct(message, version)); | ||||
|         SimpleExampleMessageData messageFromStruct = deserializeThroughStruct(buf.duplicate(), version); | ||||
|         validator.accept(messageFromStruct); | ||||
|         assertEquals(message, messageFromStruct); | ||||
|         assertEquals(message.hashCode(), messageFromStruct.hashCode()); | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -793,10 +793,25 @@ public final class MessageDataGenerator { | |||
|                                 generateArrayFromStruct(field, presentAndTaggedVersions); | ||||
|                             } else if (field.type().isBytes()) { | ||||
|                                 headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS); | ||||
|                                 buffer.printf("ByteBuffer _byteBuffer = (ByteBuffer) _taggedFields.remove(%d);%n", | ||||
|                                     field.tag().get()); | ||||
| 
 | ||||
|                                 IsNullConditional.forName("_byteBuffer"). | ||||
|                                     nullableVersions(field.nullableVersions()). | ||||
|                                     possibleVersions(field.versions()). | ||||
|                                     ifNull(() -> { | ||||
|                                         buffer.printf("this.%s = null;%n", field.camelCaseName()); | ||||
|                                     }). | ||||
|                                     ifShouldNotBeNull(() -> { | ||||
|                                         headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); | ||||
|                                 buffer.printf("this.%s = MessageUtil.byteBufferToArray(" + | ||||
|                                     "(ByteBuffer) _taggedFields.remove(%d));%n", | ||||
|                                         buffer.printf("this.%s = MessageUtil.byteBufferToArray(_byteBuffer);%n", | ||||
|                                             field.camelCaseName()); | ||||
|                                     }). | ||||
|                                     generate(buffer); | ||||
|                             } else if (field.type().isStruct()) { | ||||
|                                 buffer.printf("this.%s = new %s((Struct) _taggedFields.remove(%d), _version);%n", | ||||
|                                     field.camelCaseName(), | ||||
|                                     getBoxedJavaType(field.type()), | ||||
|                                     field.tag().get()); | ||||
|                             } else { | ||||
|                                 buffer.printf("this.%s = (%s) _taggedFields.remove(%d);%n", | ||||
|  | @ -1731,10 +1746,12 @@ public final class MessageDataGenerator { | |||
|             (field.type() instanceof FieldType.Int64FieldType) || | ||||
|             (field.type() instanceof FieldType.UUIDFieldType) || | ||||
|             (field.type() instanceof FieldType.Float64FieldType) || | ||||
|             (field.type() instanceof FieldType.StructType) || | ||||
|             (field.type() instanceof FieldType.StringFieldType)) { | ||||
|             buffer.printf("_taggedFields.put(%d, %s);%n", | ||||
|                 field.tag().get(), field.camelCaseName()); | ||||
|         } else if (field.type().isStruct()) { | ||||
|             buffer.printf("_taggedFields.put(%d, %s.toStruct(_version));%n", | ||||
|                 field.tag().get(), field.camelCaseName()); | ||||
|         } else if (field.type().isBytes()) { | ||||
|             headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS); | ||||
|             if (field.taggedVersions().intersect(field.nullableVersions()).empty()) { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue