diff --git a/spring-core/src/main/java/org/springframework/aot/hint/SimpleTypeReference.java b/spring-core/src/main/java/org/springframework/aot/hint/SimpleTypeReference.java index 9d7300f45e8..7105738d91c 100644 --- a/spring-core/src/main/java/org/springframework/aot/hint/SimpleTypeReference.java +++ b/spring-core/src/main/java/org/springframework/aot/hint/SimpleTypeReference.java @@ -16,6 +16,8 @@ package org.springframework.aot.hint; +import javax.lang.model.SourceVersion; + import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -26,6 +28,7 @@ import org.springframework.util.Assert; */ final class SimpleTypeReference extends AbstractTypeReference { + @Nullable private String canonicalName; private final String packageName; @@ -44,6 +47,9 @@ final class SimpleTypeReference extends AbstractTypeReference { static SimpleTypeReference of(String className) { Assert.notNull(className, "ClassName must not be null"); + if (!isValidClassName(className)) { + throw new IllegalStateException("Invalid class name '" + className + "'"); + } if (!className.contains("$")) { return createTypeReference(className); } @@ -55,9 +61,19 @@ final class SimpleTypeReference extends AbstractTypeReference { return typeReference; } + private static boolean isValidClassName(String className) { + for (String s : className.split("\\.", -1)) { + if (!SourceVersion.isIdentifier(s)) { + return false; + } + } + return true; + } + private static SimpleTypeReference createTypeReference(String className) { int i = className.lastIndexOf('.'); - return new SimpleTypeReference(className.substring(0, i), className.substring(i + 1), null); + return (i != -1 ? new SimpleTypeReference(className.substring(0, i), className.substring(i + 1), null) + : new SimpleTypeReference("", className, null)); } @Override @@ -65,12 +81,13 @@ final class SimpleTypeReference extends AbstractTypeReference { if (this.canonicalName == null) { StringBuilder names = new StringBuilder(); buildName(this, names); - this.canonicalName = this.packageName + "." + names; + this.canonicalName = (this.packageName.isEmpty() + ? names.toString() : this.packageName + "." + names); } return this.canonicalName; } - private static void buildName(TypeReference type, StringBuilder sb) { + private static void buildName(@Nullable TypeReference type, StringBuilder sb) { if (type == null) { return; } diff --git a/spring-core/src/test/java/org/springframework/aot/hint/SimpleTypeReferenceTests.java b/spring-core/src/test/java/org/springframework/aot/hint/SimpleTypeReferenceTests.java new file mode 100644 index 00000000000..6ce3fb68a30 --- /dev/null +++ b/spring-core/src/test/java/org/springframework/aot/hint/SimpleTypeReferenceTests.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.aot.hint; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +/** + * Tests for {@link SimpleTypeReference}. + * + * @author Stephane Nicoll + */ +class SimpleTypeReferenceTests { + + @Test + void typeReferenceInRootPackage() { + TypeReference type = SimpleTypeReference.of("MyRootClass"); + assertThat(type.getCanonicalName()).isEqualTo("MyRootClass"); + assertThat(type.getPackageName()).isEqualTo(""); + } + + @ParameterizedTest(name = "{0}") + @ValueSource(strings = { "com.example.Tes(t", "com.example..Test" }) + void typeReferenceWithInvalidClassName(String invalidClassName) { + assertThatIllegalStateException().isThrownBy(() -> SimpleTypeReference.of(invalidClassName)) + .withMessageContaining("Invalid class name"); + } + +}