diff --git a/spring-expression/src/main/java/org/springframework/expression/spel/CodeFlow.java b/spring-expression/src/main/java/org/springframework/expression/spel/CodeFlow.java index b257f47f557..659107ca9b9 100644 --- a/spring-expression/src/main/java/org/springframework/expression/spel/CodeFlow.java +++ b/spring-expression/src/main/java/org/springframework/expression/spel/CodeFlow.java @@ -1017,4 +1017,19 @@ public class CodeFlow implements Opcodes { void generateCode(MethodVisitor mv, CodeFlow codeflow); } + public static String toBoxedDescriptor(String primitiveDescriptor) { + switch (primitiveDescriptor.charAt(0)) { + case 'I': return "Ljava/lang/Integer"; + case 'J': return "Ljava/lang/Long"; + case 'F': return "Ljava/lang/Float"; + case 'D': return "Ljava/lang/Double"; + case 'B': return "Ljava/lang/Byte"; + case 'C': return "Ljava/lang/Character"; + case 'S': return "Ljava/lang/Short"; + case 'Z': return "Ljava/lang/Boolean"; + default: + throw new IllegalArgumentException("Unexpected non primitive descriptor "+primitiveDescriptor); + } + } + } diff --git a/spring-expression/src/main/java/org/springframework/expression/spel/ast/MethodReference.java b/spring-expression/src/main/java/org/springframework/expression/spel/ast/MethodReference.java index 1b68f61ecbb..1c578bac674 100644 --- a/spring-expression/src/main/java/org/springframework/expression/spel/ast/MethodReference.java +++ b/spring-expression/src/main/java/org/springframework/expression/spel/ast/MethodReference.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import org.springframework.asm.Label; import org.springframework.asm.MethodVisitor; import org.springframework.core.convert.TypeDescriptor; import org.springframework.expression.AccessException; @@ -56,6 +57,8 @@ public class MethodReference extends SpelNodeImpl { private final boolean nullSafe; + private String originalPrimitiveExitTypeDescriptor = null; + @Nullable private volatile CachedMethodExecutor cachedExecutor; @@ -236,7 +239,14 @@ public class MethodReference extends SpelNodeImpl { CachedMethodExecutor executorToCheck = this.cachedExecutor; if (executorToCheck != null && executorToCheck.get() instanceof ReflectiveMethodExecutor) { Method method = ((ReflectiveMethodExecutor) executorToCheck.get()).getMethod(); - this.exitTypeDescriptor = CodeFlow.toDescriptor(method.getReturnType()); + String descriptor = CodeFlow.toDescriptor(method.getReturnType()); + if (this.nullSafe && CodeFlow.isPrimitive(descriptor)) { + originalPrimitiveExitTypeDescriptor = descriptor; + this.exitTypeDescriptor = CodeFlow.toBoxedDescriptor(descriptor); + } + else { + this.exitTypeDescriptor = descriptor; + } } } @@ -296,17 +306,23 @@ public class MethodReference extends SpelNodeImpl { boolean isStaticMethod = Modifier.isStatic(method.getModifiers()); String descriptor = cf.lastDescriptor(); - if (descriptor == null) { - if (!isStaticMethod) { - // Nothing on the stack but something is needed - cf.loadTarget(mv); - } + Label skipIfNull = null; + if (descriptor == null && !isStaticMethod) { + // Nothing on the stack but something is needed + cf.loadTarget(mv); } - else { - if (isStaticMethod) { - // Something on the stack when nothing is needed - mv.visitInsn(POP); - } + if ((descriptor != null || !isStaticMethod) && nullSafe) { + mv.visitInsn(DUP); + skipIfNull = new Label(); + Label continueLabel = new Label(); + mv.visitJumpInsn(IFNONNULL,continueLabel); + CodeFlow.insertCheckCast(mv, this.exitTypeDescriptor); + mv.visitJumpInsn(GOTO, skipIfNull); + mv.visitLabel(continueLabel); + } + if (descriptor != null && isStaticMethod) { + // Something on the stack when nothing is needed + mv.visitInsn(POP); } if (CodeFlow.isPrimitive(descriptor)) { @@ -333,6 +349,14 @@ public class MethodReference extends SpelNodeImpl { mv.visitMethodInsn((isStaticMethod ? INVOKESTATIC : INVOKEVIRTUAL), classDesc, method.getName(), CodeFlow.createSignatureDescriptor(method), method.getDeclaringClass().isInterface()); cf.pushDescriptor(this.exitTypeDescriptor); + if (originalPrimitiveExitTypeDescriptor != null) { + // The output of the accessor will be a primitive but from the block above it might be null, + // so to have a 'common stack' element at skipIfNull target we need to box the primitive + CodeFlow.insertBoxIfNecessary(mv, originalPrimitiveExitTypeDescriptor); + } + if (skipIfNull != null) { + mv.visitLabel(skipIfNull); + } } diff --git a/spring-expression/src/main/java/org/springframework/expression/spel/ast/PropertyOrFieldReference.java b/spring-expression/src/main/java/org/springframework/expression/spel/ast/PropertyOrFieldReference.java index cbc74a9459c..3653e93b0c2 100644 --- a/spring-expression/src/main/java/org/springframework/expression/spel/ast/PropertyOrFieldReference.java +++ b/spring-expression/src/main/java/org/springframework/expression/spel/ast/PropertyOrFieldReference.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.springframework.asm.Label; import org.springframework.asm.MethodVisitor; import org.springframework.core.convert.TypeDescriptor; import org.springframework.expression.AccessException; @@ -51,6 +52,8 @@ public class PropertyOrFieldReference extends SpelNodeImpl { private final boolean nullSafe; + private String originalPrimitiveExitTypeDescriptor = null; + private final String name; @Nullable @@ -89,7 +92,7 @@ public class PropertyOrFieldReference extends SpelNodeImpl { PropertyAccessor accessorToUse = this.cachedReadAccessor; if (accessorToUse instanceof CompilablePropertyAccessor) { CompilablePropertyAccessor accessor = (CompilablePropertyAccessor) accessorToUse; - this.exitTypeDescriptor = CodeFlow.toDescriptor(accessor.getPropertyType()); + setExitTypeDescriptor(CodeFlow.toDescriptor(accessor.getPropertyType())); } return tv; } @@ -338,8 +341,40 @@ public class PropertyOrFieldReference extends SpelNodeImpl { if (!(accessorToUse instanceof CompilablePropertyAccessor)) { throw new IllegalStateException("Property accessor is not compilable: " + accessorToUse); } + Label skipIfNull = null; + if (nullSafe) { + mv.visitInsn(DUP); + skipIfNull = new Label(); + Label continueLabel = new Label(); + mv.visitJumpInsn(IFNONNULL,continueLabel); + CodeFlow.insertCheckCast(mv, this.exitTypeDescriptor); + mv.visitJumpInsn(GOTO, skipIfNull); + mv.visitLabel(continueLabel); + } ((CompilablePropertyAccessor) accessorToUse).generateCode(this.name, mv, cf); cf.pushDescriptor(this.exitTypeDescriptor); + if (originalPrimitiveExitTypeDescriptor != null) { + // The output of the accessor is a primitive but from the block above it might be null, + // so to have a common stack element type at skipIfNull target it is necessary + // to box the primitive + CodeFlow.insertBoxIfNecessary(mv, originalPrimitiveExitTypeDescriptor); + } + if (skipIfNull != null) { + mv.visitLabel(skipIfNull); + } + } + + void setExitTypeDescriptor(String descriptor) { + // If this property or field access would return a primitive - and yet + // it is also marked null safe - then the exit type descriptor must be + // promoted to the box type to allow a null value to be passed on + if (this.nullSafe && CodeFlow.isPrimitive(descriptor)) { + this.originalPrimitiveExitTypeDescriptor = descriptor; + this.exitTypeDescriptor = CodeFlow.toBoxedDescriptor(descriptor); + } + else { + this.exitTypeDescriptor = descriptor; + } } @@ -368,8 +403,7 @@ public class PropertyOrFieldReference extends SpelNodeImpl { this.ref.getValueInternal(this.contextObject, this.evalContext, this.autoGrowNullReferences); PropertyAccessor accessorToUse = this.ref.cachedReadAccessor; if (accessorToUse instanceof CompilablePropertyAccessor) { - this.ref.exitTypeDescriptor = - CodeFlow.toDescriptor(((CompilablePropertyAccessor) accessorToUse).getPropertyType()); + this.ref.setExitTypeDescriptor(CodeFlow.toDescriptor(((CompilablePropertyAccessor) accessorToUse).getPropertyType())); } return value; } diff --git a/spring-expression/src/test/java/org/springframework/expression/spel/SpelCompilationCoverageTests.java b/spring-expression/src/test/java/org/springframework/expression/spel/SpelCompilationCoverageTests.java index f654dbedfb4..03d9bda890f 100644 --- a/spring-expression/src/test/java/org/springframework/expression/spel/SpelCompilationCoverageTests.java +++ b/spring-expression/src/test/java/org/springframework/expression/spel/SpelCompilationCoverageTests.java @@ -703,7 +703,167 @@ public class SpelCompilationCoverageTests extends AbstractExpressionTests { assertCanCompile(expression); assertEquals("def", expression.getValue()); } + + @Test + public void nullsafeFieldPropertyDereferencing_SPR16489() throws Exception { + FooObjectHolder foh = new FooObjectHolder(); + StandardEvaluationContext context = new StandardEvaluationContext(); + context.setRootObject(foh); + // First non compiled: + SpelExpression expression = (SpelExpression) parser.parseExpression("foo?.object"); + assertEquals("hello",expression.getValue(context)); + foh.foo = null; + assertNull(expression.getValue(context)); + + // Now revert state of foh and try compiling it: + foh.foo = new FooObject(); + assertEquals("hello",expression.getValue(context)); + assertCanCompile(expression); + assertEquals("hello",expression.getValue(context)); + foh.foo = null; + assertNull(expression.getValue(context)); + + // Static references + expression = (SpelExpression)parser.parseExpression("#var?.propertya"); + context.setVariable("var", StaticsHelper.class); + assertEquals("sh",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", StaticsHelper.class); + assertEquals("sh",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Single size primitive (boolean) + expression = (SpelExpression)parser.parseExpression("#var?.a"); + context.setVariable("var", new TestClass4()); + assertFalse((Boolean)expression.getValue(context)); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", new TestClass4()); + assertFalse((Boolean)expression.getValue(context)); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Double slot primitives + expression = (SpelExpression)parser.parseExpression("#var?.four"); + context.setVariable("var", new Three()); + assertEquals("0.04",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", new Three()); + assertEquals("0.04",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + } + + @Test + public void nullsafeMethodChaining_SPR16489() throws Exception { + FooObjectHolder foh = new FooObjectHolder(); + StandardEvaluationContext context = new StandardEvaluationContext(); + context.setRootObject(foh); + + // First non compiled: + SpelExpression expression = (SpelExpression) parser.parseExpression("getFoo()?.getObject()"); + assertEquals("hello",expression.getValue(context)); + foh.foo = null; + assertNull(expression.getValue(context)); + assertCanCompile(expression); + foh.foo = new FooObject(); + assertEquals("hello",expression.getValue(context)); + foh.foo = null; + assertNull(expression.getValue(context)); + + // Static method references + expression = (SpelExpression)parser.parseExpression("#var?.methoda()"); + context.setVariable("var", StaticsHelper.class); + assertEquals("sh",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", StaticsHelper.class); + assertEquals("sh",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.intValue()"); + context.setVariable("var", 4); + assertEquals("4",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", 4); + assertEquals("4",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.booleanValue()"); + context.setVariable("var", false); + assertEquals("false",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", false); + assertEquals("false",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.booleanValue()"); + context.setVariable("var", true); + assertEquals("true",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", true); + assertEquals("true",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.longValue()"); + context.setVariable("var", 5L); + assertEquals("5",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", 5L); + assertEquals("5",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.floatValue()"); + context.setVariable("var", 3f); + assertEquals("3.0",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", 3f); + assertEquals("3.0",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.shortValue()"); + context.setVariable("var", (short)8); + assertEquals("8",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", (short)8); + assertEquals("8",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + } + @Test public void elvis() throws Exception { Expression expression = parser.parseExpression("'a'?:'b'"); @@ -3065,19 +3225,47 @@ public class SpelCompilationCoverageTests extends AbstractExpressionTests { assertEquals(1.0f, expression.getValue()); } + @Test + public void compilationOfBasicNullSafeMethodReference() { + SpelExpressionParser parser = new SpelExpressionParser( + new SpelParserConfiguration(SpelCompilerMode.OFF, getClass().getClassLoader())); + SpelExpression expression = parser.parseRaw("#it?.equals(3)"); + StandardEvaluationContext context = new StandardEvaluationContext(new Object[] {1}); + context.setVariable("it", 3); + expression.setEvaluationContext(context); + assertTrue(expression.getValue(Boolean.class)); + context.setVariable("it", null); + assertNull(expression.getValue(Boolean.class)); + + assertCanCompile(expression); + + context.setVariable("it", 3); + assertTrue(expression.getValue(Boolean.class)); + context.setVariable("it", null); + assertNull(expression.getValue(Boolean.class)); + } + @Test public void failsWhenSettingContextForExpression_SPR12326() { SpelExpressionParser parser = new SpelExpressionParser( - new SpelParserConfiguration(SpelCompilerMode.IMMEDIATE, getClass().getClassLoader())); + new SpelParserConfiguration(SpelCompilerMode.OFF, getClass().getClassLoader())); Person3 person = new Person3("foo", 1); SpelExpression expression = parser.parseRaw("#it?.age?.equals([0])"); StandardEvaluationContext context = new StandardEvaluationContext(new Object[] {1}); context.setVariable("it", person); expression.setEvaluationContext(context); assertTrue(expression.getValue(Boolean.class)); + // This will trigger compilation (second usage) assertTrue(expression.getValue(Boolean.class)); + context.setVariable("it", null); + assertNull(expression.getValue(Boolean.class)); + assertCanCompile(expression); + + context.setVariable("it", person); assertTrue(expression.getValue(Boolean.class)); + context.setVariable("it", null); + assertNull(expression.getValue(Boolean.class)); } @@ -5100,6 +5288,14 @@ public class SpelCompilationCoverageTests extends AbstractExpressionTests { } } + public static class FooObjectHolder { + + private FooObject foo = new FooObject(); + + public FooObject getFoo() { + return foo; + } + } public static class FooObject {