KAFKA-8168; Add a generated ApiMessageType class

Author: Colin P. Mccabe <cmccabe@confluent.io>

Reviewers: Gwen Shapira

Closes #6510 from cmccabe/KAFKA-8168
This commit is contained in:
Colin P. Mccabe 2019-04-05 15:27:34 -07:00 committed by Gwen Shapira
parent cc4fde35c9
commit 62381bd915
6 changed files with 361 additions and 134 deletions

View File

@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.message;
import org.apache.kafka.common.errors.UnsupportedVersionException;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import java.util.HashSet;
import java.util.Set;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
public class ApiMessageTypeTest {
@Rule
final public Timeout globalTimeout = Timeout.millis(120000);
@Test
public void testFromApiKey() {
for (ApiMessageType type : ApiMessageType.values()) {
ApiMessageType type2 = ApiMessageType.fromApiKey(type.apiKey());
assertEquals(type2, type);
}
}
@Test
public void testInvalidFromApiKey() {
try {
ApiMessageType.fromApiKey((short) -1);
fail("expected to get an UnsupportedVersionException");
} catch (UnsupportedVersionException uve) {
// expected
}
}
@Test
public void testUniqueness() {
Set<Short> ids = new HashSet<>();
Set<String> requestNames = new HashSet<>();
Set<String> responseNames = new HashSet<>();
for (ApiMessageType type : ApiMessageType.values()) {
assertFalse("found two ApiMessageType objects with id " + type.apiKey(),
ids.contains(type.apiKey()));
ids.add(type.apiKey());
String requestName = type.newRequest().getClass().getSimpleName();
assertFalse("found two ApiMessageType objects with requestName " + requestName,
requestNames.contains(requestName));
requestNames.add(requestName);
String responseName = type.newResponse().getClass().getSimpleName();
assertFalse("found two ApiMessageType objects with responseName " + responseName,
responseNames.contains(responseName));
responseNames.add(responseName);
}
assertEquals(ApiMessageType.values().length, ids.size());
assertEquals(ApiMessageType.values().length, requestNames.size());
assertEquals(ApiMessageType.values().length, responseNames.size());
}
}

View File

@ -134,7 +134,7 @@ public final class MessageTest {
for (ApiKeys apiKey : ApiKeys.values()) { for (ApiKeys apiKey : ApiKeys.values()) {
Message message = null; Message message = null;
try { try {
message = ApiMessageFactory.newRequest(apiKey.id); message = ApiMessageType.fromApiKey(apiKey.id).newRequest();
} catch (UnsupportedVersionException e) { } catch (UnsupportedVersionException e) {
fail("No request message spec found for API " + apiKey); fail("No request message spec found for API " + apiKey);
} }
@ -142,7 +142,7 @@ public final class MessageTest {
"supports versions up to " + message.highestSupportedVersion(), "supports versions up to " + message.highestSupportedVersion(),
apiKey.latestVersion() <= message.highestSupportedVersion()); apiKey.latestVersion() <= message.highestSupportedVersion());
try { try {
message = ApiMessageFactory.newResponse(apiKey.id); message = ApiMessageType.fromApiKey(apiKey.id).newResponse();
} catch (UnsupportedVersionException e) { } catch (UnsupportedVersionException e) {
fail("No response message spec found for API " + apiKey); fail("No response message spec found for API " + apiKey);
} }
@ -159,7 +159,7 @@ public final class MessageTest {
public void testRequestSchemas() throws Exception { public void testRequestSchemas() throws Exception {
for (ApiKeys apiKey : ApiKeys.values()) { for (ApiKeys apiKey : ApiKeys.values()) {
Schema[] manualSchemas = apiKey.requestSchemas; Schema[] manualSchemas = apiKey.requestSchemas;
Schema[] generatedSchemas = ApiMessageFactory.requestSchemas(apiKey.id); Schema[] generatedSchemas = ApiMessageType.fromApiKey(apiKey.id).requestSchemas();
Assert.assertEquals("Mismatching request SCHEMAS lengths " + Assert.assertEquals("Mismatching request SCHEMAS lengths " +
"for api key " + apiKey, manualSchemas.length, generatedSchemas.length); "for api key " + apiKey, manualSchemas.length, generatedSchemas.length);
for (int v = 0; v < manualSchemas.length; v++) { for (int v = 0; v < manualSchemas.length; v++) {
@ -182,7 +182,7 @@ public final class MessageTest {
public void testResponseSchemas() throws Exception { public void testResponseSchemas() throws Exception {
for (ApiKeys apiKey : ApiKeys.values()) { for (ApiKeys apiKey : ApiKeys.values()) {
Schema[] manualSchemas = apiKey.responseSchemas; Schema[] manualSchemas = apiKey.responseSchemas;
Schema[] generatedSchemas = ApiMessageFactory.responseSchemas(apiKey.id); Schema[] generatedSchemas = ApiMessageType.fromApiKey(apiKey.id).responseSchemas();
Assert.assertEquals("Mismatching response SCHEMAS lengths " + Assert.assertEquals("Mismatching response SCHEMAS lengths " +
"for api key " + apiKey, manualSchemas.length, generatedSchemas.length); "for api key " + apiKey, manualSchemas.length, generatedSchemas.length);
for (int v = 0; v < manualSchemas.length; v++) { for (int v = 0; v < manualSchemas.length; v++) {

View File

@ -1,123 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.message;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.Map;
import java.util.TreeMap;
public final class ApiMessageFactoryGenerator {
private final TreeMap<Short, String> requestApis;
private final TreeMap<Short, String> responseApis;
private final HeaderGenerator headerGenerator;
private final CodeBuffer buffer;
public void registerMessageType(MessageSpec spec) {
if (spec.type() == MessageSpecType.REQUEST) {
if (requestApis.containsKey(spec.apiKey().get())) {
throw new RuntimeException("Found more than one request with " +
"API key " + spec.apiKey().get());
}
requestApis.put(spec.apiKey().get(), spec.generatedClassName());
} else if (spec.type() == MessageSpecType.RESPONSE) {
if (responseApis.containsKey(spec.apiKey().get())) {
throw new RuntimeException("Found more than one response with " +
"API key " + spec.apiKey().get());
}
responseApis.put(spec.apiKey().get(), spec.generatedClassName());
}
}
public ApiMessageFactoryGenerator() {
this.requestApis = new TreeMap<>();
this.responseApis = new TreeMap<>();
this.headerGenerator = new HeaderGenerator();
this.buffer = new CodeBuffer();
}
public void generate() {
buffer.printf("public final class ApiMessageFactory {%n");
buffer.incrementIndent();
generateFactoryMethod("request", requestApis);
buffer.printf("%n");
generateFactoryMethod("response", responseApis);
buffer.printf("%n");
generateSchemasAccessor("request", requestApis);
buffer.printf("%n");
generateSchemasAccessor("response", responseApis);
buffer.decrementIndent();
buffer.printf("}%n");
headerGenerator.generate();
}
public void generateFactoryMethod(String type, TreeMap<Short, String> apis) {
headerGenerator.addImport(MessageGenerator.MESSAGE_CLASS);
buffer.printf("public static Message new%s(short apiKey) {%n",
MessageGenerator.capitalizeFirst(type));
buffer.incrementIndent();
buffer.printf("switch (apiKey) {%n");
buffer.incrementIndent();
for (Map.Entry<Short, String> entry : apis.entrySet()) {
buffer.printf("case %d:%n", entry.getKey());
buffer.incrementIndent();
buffer.printf("return new %s();%n", entry.getValue());
buffer.decrementIndent();
}
buffer.printf("default:%n");
buffer.incrementIndent();
headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS);
buffer.printf("throw new UnsupportedVersionException(\"Unsupported %s API key \"" +
" + apiKey);%n", type);
buffer.decrementIndent();
buffer.decrementIndent();
buffer.printf("}%n");
buffer.decrementIndent();
buffer.printf("}%n");
}
public void generateSchemasAccessor(String type, TreeMap<Short, String> apis) {
headerGenerator.addImport(MessageGenerator.SCHEMA_CLASS);
buffer.printf("public static Schema[] %sSchemas(short apiKey) {%n",
MessageGenerator.lowerCaseFirst(type));
buffer.incrementIndent();
buffer.printf("switch (apiKey) {%n");
buffer.incrementIndent();
for (Map.Entry<Short, String> entry : apis.entrySet()) {
buffer.printf("case %d:%n", entry.getKey());
buffer.incrementIndent();
buffer.printf("return %s.SCHEMAS;%n", entry.getValue());
buffer.decrementIndent();
}
buffer.printf("default:%n");
buffer.incrementIndent();
headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS);
buffer.printf("throw new UnsupportedVersionException(\"Unsupported %s API key \"" +
" + apiKey);%n", type);
buffer.decrementIndent();
buffer.decrementIndent();
buffer.printf("}%n");
buffer.decrementIndent();
buffer.printf("}%n");
}
public void write(BufferedWriter writer) throws IOException {
headerGenerator.buffer().write(writer);
buffer.write(writer);
}
}

View File

@ -0,0 +1,248 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.message;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;
public final class ApiMessageTypeGenerator {
private final HeaderGenerator headerGenerator;
private final CodeBuffer buffer;
private final TreeMap<Short, ApiData> apis;
private static final class ApiData {
short apiKey;
MessageSpec requestSpec;
MessageSpec responseSpec;
ApiData(short apiKey) {
this.apiKey = apiKey;
}
String name() {
if (requestSpec != null) {
return MessageGenerator.stripSuffix(requestSpec.name(),
MessageGenerator.REQUEST_SUFFIX);
} else if (responseSpec != null) {
return MessageGenerator.stripSuffix(responseSpec.name(),
MessageGenerator.RESPONSE_SUFFIX);
} else {
throw new RuntimeException("Neither requestSpec nor responseSpec is defined " +
"for API key " + apiKey);
}
}
String requestSchema() {
if (requestSpec == null) {
return "null";
} else {
return String.format("%sData.SCHEMAS", requestSpec.name());
}
}
String responseSchema() {
if (responseSpec == null) {
return "null";
} else {
return String.format("%sData.SCHEMAS", responseSpec.name());
}
}
}
public ApiMessageTypeGenerator() {
this.headerGenerator = new HeaderGenerator();
this.apis = new TreeMap<>();
this.buffer = new CodeBuffer();
}
public void registerMessageType(MessageSpec spec) {
switch (spec.type()) {
case REQUEST: {
short apiKey = spec.apiKey().get();
ApiData data = apis.get(apiKey);
if (!apis.containsKey(apiKey)) {
data = new ApiData(apiKey);
apis.put(apiKey, data);
}
if (data.requestSpec != null) {
throw new RuntimeException("Found more than one request with " +
"API key " + spec.apiKey().get());
}
data.requestSpec = spec;
break;
}
case RESPONSE: {
short apiKey = spec.apiKey().get();
ApiData data = apis.get(apiKey);
if (!apis.containsKey(apiKey)) {
data = new ApiData(apiKey);
apis.put(apiKey, data);
}
if (data.responseSpec != null) {
throw new RuntimeException("Found more than one response with " +
"API key " + spec.apiKey().get());
}
data.responseSpec = spec;
break;
}
default:
// do nothing
break;
}
}
public void generate() {
buffer.printf("public enum ApiMessageType {%n");
buffer.incrementIndent();
generateEnumValues();
buffer.printf("%n");
generateInstanceVariables();
buffer.printf("%n");
generateEnumConstructor();
buffer.printf("%n");
generateFromApiKey();
buffer.printf("%n");
generateNewApiMessageMethod("request");
buffer.printf("%n");
generateNewApiMessageMethod("response");
buffer.printf("%n");
generateAccessor("apiKey", "short");
buffer.printf("%n");
generateAccessor("requestSchemas", "Schema[]");
buffer.printf("%n");
generateAccessor("responseSchemas", "Schema[]");
buffer.printf("%n");
generateToString();
buffer.decrementIndent();
buffer.printf("}%n");
headerGenerator.generate();
}
private void generateEnumValues() {
int numProcessed = 0;
for (Map.Entry<Short, ApiData> entry : apis.entrySet()) {
ApiData apiData = entry.getValue();
String name = apiData.name();
numProcessed++;
buffer.printf("%s(\"%s\", (short) %d, %s, %s)%s%n",
MessageGenerator.toSnakeCase(name).toUpperCase(Locale.ROOT),
MessageGenerator.capitalizeFirst(name),
entry.getKey(),
apiData.requestSchema(),
apiData.responseSchema(),
(numProcessed == apis.size()) ? ";" : ",");
}
}
private void generateInstanceVariables() {
buffer.printf("private final String name;%n");
buffer.printf("private final short apiKey;%n");
buffer.printf("private final Schema[] requestSchemas;%n");
buffer.printf("private final Schema[] responseSchemas;%n");
headerGenerator.addImport(MessageGenerator.SCHEMA_CLASS);
}
private void generateEnumConstructor() {
buffer.printf("ApiMessageType(String name, short apiKey, " +
"Schema[] requestSchemas, Schema[] responseSchemas) {%n");
buffer.incrementIndent();
buffer.printf("this.name = name;%n");
buffer.printf("this.apiKey = apiKey;%n");
buffer.printf("this.requestSchemas = requestSchemas;%n");
buffer.printf("this.responseSchemas = responseSchemas;%n");
buffer.decrementIndent();
buffer.printf("}%n");
}
private void generateFromApiKey() {
buffer.printf("public static ApiMessageType fromApiKey(short apiKey) {%n");
buffer.incrementIndent();
buffer.printf("switch (apiKey) {%n");
buffer.incrementIndent();
for (Map.Entry<Short, ApiData> entry : apis.entrySet()) {
ApiData apiData = entry.getValue();
String name = apiData.name();
buffer.printf("case %d:%n", entry.getKey());
buffer.incrementIndent();
buffer.printf("return %s;%n", MessageGenerator.toSnakeCase(name).toUpperCase(Locale.ROOT));
buffer.decrementIndent();
}
buffer.printf("default:%n");
buffer.incrementIndent();
headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS);
buffer.printf("throw new UnsupportedVersionException(\"Unsupported API key \"" +
" + apiKey);%n");
buffer.decrementIndent();
buffer.decrementIndent();
buffer.printf("}%n");
buffer.decrementIndent();
buffer.printf("}%n");
}
private void generateNewApiMessageMethod(String type) {
headerGenerator.addImport(MessageGenerator.API_MESSAGE_CLASS);
buffer.printf("public ApiMessage new%s() {%n",
MessageGenerator.capitalizeFirst(type));
buffer.incrementIndent();
buffer.printf("switch (apiKey) {%n");
buffer.incrementIndent();
for (Map.Entry<Short, ApiData> entry : apis.entrySet()) {
buffer.printf("case %d:%n", entry.getKey());
buffer.incrementIndent();
buffer.printf("return new %s%sData();%n",
entry.getValue().name(),
MessageGenerator.capitalizeFirst(type));
buffer.decrementIndent();
}
buffer.printf("default:%n");
buffer.incrementIndent();
headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS);
buffer.printf("throw new UnsupportedVersionException(\"Unsupported %s API key \"" +
" + apiKey);%n", type);
buffer.decrementIndent();
buffer.decrementIndent();
buffer.printf("}%n");
buffer.decrementIndent();
buffer.printf("}%n");
}
private void generateAccessor(String name, String type) {
buffer.printf("public %s %s() {%n", type, name);
buffer.incrementIndent();
buffer.printf("return this.%s;%n", name);
buffer.decrementIndent();
buffer.printf("}%n");
}
private void generateToString() {
buffer.printf("@Override%n");
buffer.printf("public String toString() {%n");
buffer.incrementIndent();
buffer.printf("return this.name();%n");
buffer.decrementIndent();
buffer.printf("}%n");
}
public void write(BufferedWriter writer) throws IOException {
headerGenerator.buffer().write(writer);
buffer.write(writer);
}
}

View File

@ -41,7 +41,7 @@ public final class MessageGenerator {
static final String JAVA_SUFFIX = ".java"; static final String JAVA_SUFFIX = ".java";
static final String API_MESSAGE_FACTORY_JAVA = "ApiMessageFactory.java"; static final String API_MESSAGE_TYPE_JAVA = "ApiMessageType.java";
static final String API_MESSAGE_CLASS = "org.apache.kafka.common.protocol.ApiMessage"; static final String API_MESSAGE_CLASS = "org.apache.kafka.common.protocol.ApiMessage";
@ -79,6 +79,10 @@ public final class MessageGenerator {
static final String BYTES_CLASS = "org.apache.kafka.common.utils.Bytes"; static final String BYTES_CLASS = "org.apache.kafka.common.utils.Bytes";
static final String REQUEST_SUFFIX = "Request";
static final String RESPONSE_SUFFIX = "Response";
/** /**
* The Jackson serializer we use for JSON objects. * The Jackson serializer we use for JSON objects.
*/ */
@ -95,7 +99,7 @@ public final class MessageGenerator {
public static void processDirectories(String outputDir, String inputDir) throws Exception { public static void processDirectories(String outputDir, String inputDir) throws Exception {
Files.createDirectories(Paths.get(outputDir)); Files.createDirectories(Paths.get(outputDir));
int numProcessed = 0; int numProcessed = 0;
ApiMessageFactoryGenerator messageFactoryGenerator = new ApiMessageFactoryGenerator(); ApiMessageTypeGenerator messageTypeGenerator = new ApiMessageTypeGenerator();
HashSet<String> outputFileNames = new HashSet<>(); HashSet<String> outputFileNames = new HashSet<>();
try (DirectoryStream<Path> directoryStream = Files try (DirectoryStream<Path> directoryStream = Files
.newDirectoryStream(Paths.get(inputDir), JSON_GLOB)) { .newDirectoryStream(Paths.get(inputDir), JSON_GLOB)) {
@ -112,17 +116,17 @@ public final class MessageGenerator {
generator.write(writer); generator.write(writer);
} }
numProcessed++; numProcessed++;
messageFactoryGenerator.registerMessageType(spec); messageTypeGenerator.registerMessageType(spec);
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException("Exception while processing " + inputPath.toString(), e); throw new RuntimeException("Exception while processing " + inputPath.toString(), e);
} }
} }
} }
Path factoryOutputPath = Paths.get(outputDir, API_MESSAGE_FACTORY_JAVA); Path factoryOutputPath = Paths.get(outputDir, API_MESSAGE_TYPE_JAVA);
outputFileNames.add(API_MESSAGE_FACTORY_JAVA); outputFileNames.add(API_MESSAGE_TYPE_JAVA);
try (BufferedWriter writer = Files.newBufferedWriter(factoryOutputPath)) { try (BufferedWriter writer = Files.newBufferedWriter(factoryOutputPath)) {
messageFactoryGenerator.generate(); messageTypeGenerator.generate();
messageFactoryGenerator.write(writer); messageTypeGenerator.write(writer);
} }
numProcessed++; numProcessed++;
try (DirectoryStream<Path> directoryStream = Files. try (DirectoryStream<Path> directoryStream = Files.
@ -181,6 +185,15 @@ public final class MessageGenerator {
return bld.toString(); return bld.toString();
} }
static String stripSuffix(String str, String suffix) {
if (str.endsWith(suffix)) {
return str.substring(0, str.length() - suffix.length());
} else {
throw new RuntimeException("String " + str + " does not end with the " +
"expected suffix " + suffix);
}
}
private final static String USAGE = "MessageGenerator: [output Java file] [input JSON file]"; private final static String USAGE = "MessageGenerator: [output Java file] [input JSON file]";
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {

View File

@ -24,6 +24,7 @@ import org.junit.rules.Timeout;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
public class MessageGeneratorTest { public class MessageGeneratorTest {
@Rule @Rule
@ -56,4 +57,16 @@ public class MessageGeneratorTest {
assertEquals("foo_bar_baz", MessageGenerator.toSnakeCase("fooBarBaz")); assertEquals("foo_bar_baz", MessageGenerator.toSnakeCase("fooBarBaz"));
assertEquals("fortran", MessageGenerator.toSnakeCase("FORTRAN")); assertEquals("fortran", MessageGenerator.toSnakeCase("FORTRAN"));
} }
@Test
public void stripSuffixTest() throws Exception {
assertEquals("FooBa", MessageGenerator.stripSuffix("FooBar", "r"));
assertEquals("", MessageGenerator.stripSuffix("FooBar", "FooBar"));
assertEquals("Foo", MessageGenerator.stripSuffix("FooBar", "Bar"));
try {
MessageGenerator.stripSuffix("FooBar", "Baz");
fail("expected exception");
} catch (RuntimeException e) {
}
}
} }