From 1db7e02de3eb0c011ee6681f5a12eb9d166fea81 Mon Sep 17 00:00:00 2001 From: Andy Clement Date: Mon, 12 Mar 2018 11:11:40 -0700 Subject: [PATCH] Modify SpEL code gen to take account of null safe refs With this change the code generation for method and property references is modified to include branching logic in the case of null safe dereferencing (?.). This is complicated by the possible usage of primitives on the left hand side of the dereference. To cope with this case primitives are promoted to boxed types when this situation occurs enabling null to be passed as a possible result. Issue: SPR-16489 --- .../expression/spel/CodeFlow.java | 15 ++ .../expression/spel/ast/MethodReference.java | 46 +++- .../spel/ast/PropertyOrFieldReference.java | 40 +++- .../spel/SpelCompilationCoverageTests.java | 198 +++++++++++++++++- 4 files changed, 284 insertions(+), 15 deletions(-) diff --git a/spring-expression/src/main/java/org/springframework/expression/spel/CodeFlow.java b/spring-expression/src/main/java/org/springframework/expression/spel/CodeFlow.java index b257f47f557..659107ca9b9 100644 --- a/spring-expression/src/main/java/org/springframework/expression/spel/CodeFlow.java +++ b/spring-expression/src/main/java/org/springframework/expression/spel/CodeFlow.java @@ -1017,4 +1017,19 @@ public class CodeFlow implements Opcodes { void generateCode(MethodVisitor mv, CodeFlow codeflow); } + public static String toBoxedDescriptor(String primitiveDescriptor) { + switch (primitiveDescriptor.charAt(0)) { + case 'I': return "Ljava/lang/Integer"; + case 'J': return "Ljava/lang/Long"; + case 'F': return "Ljava/lang/Float"; + case 'D': return "Ljava/lang/Double"; + case 'B': return "Ljava/lang/Byte"; + case 'C': return "Ljava/lang/Character"; + case 'S': return "Ljava/lang/Short"; + case 'Z': return "Ljava/lang/Boolean"; + default: + throw new IllegalArgumentException("Unexpected non primitive descriptor "+primitiveDescriptor); + } + } + } diff --git a/spring-expression/src/main/java/org/springframework/expression/spel/ast/MethodReference.java b/spring-expression/src/main/java/org/springframework/expression/spel/ast/MethodReference.java index 1b68f61ecbb..1c578bac674 100644 --- a/spring-expression/src/main/java/org/springframework/expression/spel/ast/MethodReference.java +++ b/spring-expression/src/main/java/org/springframework/expression/spel/ast/MethodReference.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import org.springframework.asm.Label; import org.springframework.asm.MethodVisitor; import org.springframework.core.convert.TypeDescriptor; import org.springframework.expression.AccessException; @@ -56,6 +57,8 @@ public class MethodReference extends SpelNodeImpl { private final boolean nullSafe; + private String originalPrimitiveExitTypeDescriptor = null; + @Nullable private volatile CachedMethodExecutor cachedExecutor; @@ -236,7 +239,14 @@ public class MethodReference extends SpelNodeImpl { CachedMethodExecutor executorToCheck = this.cachedExecutor; if (executorToCheck != null && executorToCheck.get() instanceof ReflectiveMethodExecutor) { Method method = ((ReflectiveMethodExecutor) executorToCheck.get()).getMethod(); - this.exitTypeDescriptor = CodeFlow.toDescriptor(method.getReturnType()); + String descriptor = CodeFlow.toDescriptor(method.getReturnType()); + if (this.nullSafe && CodeFlow.isPrimitive(descriptor)) { + originalPrimitiveExitTypeDescriptor = descriptor; + this.exitTypeDescriptor = CodeFlow.toBoxedDescriptor(descriptor); + } + else { + this.exitTypeDescriptor = descriptor; + } } } @@ -296,17 +306,23 @@ public class MethodReference extends SpelNodeImpl { boolean isStaticMethod = Modifier.isStatic(method.getModifiers()); String descriptor = cf.lastDescriptor(); - if (descriptor == null) { - if (!isStaticMethod) { - // Nothing on the stack but something is needed - cf.loadTarget(mv); - } + Label skipIfNull = null; + if (descriptor == null && !isStaticMethod) { + // Nothing on the stack but something is needed + cf.loadTarget(mv); } - else { - if (isStaticMethod) { - // Something on the stack when nothing is needed - mv.visitInsn(POP); - } + if ((descriptor != null || !isStaticMethod) && nullSafe) { + mv.visitInsn(DUP); + skipIfNull = new Label(); + Label continueLabel = new Label(); + mv.visitJumpInsn(IFNONNULL,continueLabel); + CodeFlow.insertCheckCast(mv, this.exitTypeDescriptor); + mv.visitJumpInsn(GOTO, skipIfNull); + mv.visitLabel(continueLabel); + } + if (descriptor != null && isStaticMethod) { + // Something on the stack when nothing is needed + mv.visitInsn(POP); } if (CodeFlow.isPrimitive(descriptor)) { @@ -333,6 +349,14 @@ public class MethodReference extends SpelNodeImpl { mv.visitMethodInsn((isStaticMethod ? INVOKESTATIC : INVOKEVIRTUAL), classDesc, method.getName(), CodeFlow.createSignatureDescriptor(method), method.getDeclaringClass().isInterface()); cf.pushDescriptor(this.exitTypeDescriptor); + if (originalPrimitiveExitTypeDescriptor != null) { + // The output of the accessor will be a primitive but from the block above it might be null, + // so to have a 'common stack' element at skipIfNull target we need to box the primitive + CodeFlow.insertBoxIfNecessary(mv, originalPrimitiveExitTypeDescriptor); + } + if (skipIfNull != null) { + mv.visitLabel(skipIfNull); + } } diff --git a/spring-expression/src/main/java/org/springframework/expression/spel/ast/PropertyOrFieldReference.java b/spring-expression/src/main/java/org/springframework/expression/spel/ast/PropertyOrFieldReference.java index cbc74a9459c..3653e93b0c2 100644 --- a/spring-expression/src/main/java/org/springframework/expression/spel/ast/PropertyOrFieldReference.java +++ b/spring-expression/src/main/java/org/springframework/expression/spel/ast/PropertyOrFieldReference.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.springframework.asm.Label; import org.springframework.asm.MethodVisitor; import org.springframework.core.convert.TypeDescriptor; import org.springframework.expression.AccessException; @@ -51,6 +52,8 @@ public class PropertyOrFieldReference extends SpelNodeImpl { private final boolean nullSafe; + private String originalPrimitiveExitTypeDescriptor = null; + private final String name; @Nullable @@ -89,7 +92,7 @@ public class PropertyOrFieldReference extends SpelNodeImpl { PropertyAccessor accessorToUse = this.cachedReadAccessor; if (accessorToUse instanceof CompilablePropertyAccessor) { CompilablePropertyAccessor accessor = (CompilablePropertyAccessor) accessorToUse; - this.exitTypeDescriptor = CodeFlow.toDescriptor(accessor.getPropertyType()); + setExitTypeDescriptor(CodeFlow.toDescriptor(accessor.getPropertyType())); } return tv; } @@ -338,8 +341,40 @@ public class PropertyOrFieldReference extends SpelNodeImpl { if (!(accessorToUse instanceof CompilablePropertyAccessor)) { throw new IllegalStateException("Property accessor is not compilable: " + accessorToUse); } + Label skipIfNull = null; + if (nullSafe) { + mv.visitInsn(DUP); + skipIfNull = new Label(); + Label continueLabel = new Label(); + mv.visitJumpInsn(IFNONNULL,continueLabel); + CodeFlow.insertCheckCast(mv, this.exitTypeDescriptor); + mv.visitJumpInsn(GOTO, skipIfNull); + mv.visitLabel(continueLabel); + } ((CompilablePropertyAccessor) accessorToUse).generateCode(this.name, mv, cf); cf.pushDescriptor(this.exitTypeDescriptor); + if (originalPrimitiveExitTypeDescriptor != null) { + // The output of the accessor is a primitive but from the block above it might be null, + // so to have a common stack element type at skipIfNull target it is necessary + // to box the primitive + CodeFlow.insertBoxIfNecessary(mv, originalPrimitiveExitTypeDescriptor); + } + if (skipIfNull != null) { + mv.visitLabel(skipIfNull); + } + } + + void setExitTypeDescriptor(String descriptor) { + // If this property or field access would return a primitive - and yet + // it is also marked null safe - then the exit type descriptor must be + // promoted to the box type to allow a null value to be passed on + if (this.nullSafe && CodeFlow.isPrimitive(descriptor)) { + this.originalPrimitiveExitTypeDescriptor = descriptor; + this.exitTypeDescriptor = CodeFlow.toBoxedDescriptor(descriptor); + } + else { + this.exitTypeDescriptor = descriptor; + } } @@ -368,8 +403,7 @@ public class PropertyOrFieldReference extends SpelNodeImpl { this.ref.getValueInternal(this.contextObject, this.evalContext, this.autoGrowNullReferences); PropertyAccessor accessorToUse = this.ref.cachedReadAccessor; if (accessorToUse instanceof CompilablePropertyAccessor) { - this.ref.exitTypeDescriptor = - CodeFlow.toDescriptor(((CompilablePropertyAccessor) accessorToUse).getPropertyType()); + this.ref.setExitTypeDescriptor(CodeFlow.toDescriptor(((CompilablePropertyAccessor) accessorToUse).getPropertyType())); } return value; } diff --git a/spring-expression/src/test/java/org/springframework/expression/spel/SpelCompilationCoverageTests.java b/spring-expression/src/test/java/org/springframework/expression/spel/SpelCompilationCoverageTests.java index f654dbedfb4..03d9bda890f 100644 --- a/spring-expression/src/test/java/org/springframework/expression/spel/SpelCompilationCoverageTests.java +++ b/spring-expression/src/test/java/org/springframework/expression/spel/SpelCompilationCoverageTests.java @@ -703,7 +703,167 @@ public class SpelCompilationCoverageTests extends AbstractExpressionTests { assertCanCompile(expression); assertEquals("def", expression.getValue()); } + + @Test + public void nullsafeFieldPropertyDereferencing_SPR16489() throws Exception { + FooObjectHolder foh = new FooObjectHolder(); + StandardEvaluationContext context = new StandardEvaluationContext(); + context.setRootObject(foh); + // First non compiled: + SpelExpression expression = (SpelExpression) parser.parseExpression("foo?.object"); + assertEquals("hello",expression.getValue(context)); + foh.foo = null; + assertNull(expression.getValue(context)); + + // Now revert state of foh and try compiling it: + foh.foo = new FooObject(); + assertEquals("hello",expression.getValue(context)); + assertCanCompile(expression); + assertEquals("hello",expression.getValue(context)); + foh.foo = null; + assertNull(expression.getValue(context)); + + // Static references + expression = (SpelExpression)parser.parseExpression("#var?.propertya"); + context.setVariable("var", StaticsHelper.class); + assertEquals("sh",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", StaticsHelper.class); + assertEquals("sh",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Single size primitive (boolean) + expression = (SpelExpression)parser.parseExpression("#var?.a"); + context.setVariable("var", new TestClass4()); + assertFalse((Boolean)expression.getValue(context)); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", new TestClass4()); + assertFalse((Boolean)expression.getValue(context)); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Double slot primitives + expression = (SpelExpression)parser.parseExpression("#var?.four"); + context.setVariable("var", new Three()); + assertEquals("0.04",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", new Three()); + assertEquals("0.04",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + } + + @Test + public void nullsafeMethodChaining_SPR16489() throws Exception { + FooObjectHolder foh = new FooObjectHolder(); + StandardEvaluationContext context = new StandardEvaluationContext(); + context.setRootObject(foh); + + // First non compiled: + SpelExpression expression = (SpelExpression) parser.parseExpression("getFoo()?.getObject()"); + assertEquals("hello",expression.getValue(context)); + foh.foo = null; + assertNull(expression.getValue(context)); + assertCanCompile(expression); + foh.foo = new FooObject(); + assertEquals("hello",expression.getValue(context)); + foh.foo = null; + assertNull(expression.getValue(context)); + + // Static method references + expression = (SpelExpression)parser.parseExpression("#var?.methoda()"); + context.setVariable("var", StaticsHelper.class); + assertEquals("sh",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", StaticsHelper.class); + assertEquals("sh",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.intValue()"); + context.setVariable("var", 4); + assertEquals("4",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", 4); + assertEquals("4",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.booleanValue()"); + context.setVariable("var", false); + assertEquals("false",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", false); + assertEquals("false",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.booleanValue()"); + context.setVariable("var", true); + assertEquals("true",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", true); + assertEquals("true",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.longValue()"); + context.setVariable("var", 5L); + assertEquals("5",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", 5L); + assertEquals("5",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.floatValue()"); + context.setVariable("var", 3f); + assertEquals("3.0",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", 3f); + assertEquals("3.0",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + + // Nullsafe guard on expression element evaluating to primitive/null + expression = (SpelExpression)parser.parseExpression("#var?.shortValue()"); + context.setVariable("var", (short)8); + assertEquals("8",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + assertCanCompile(expression); + context.setVariable("var", (short)8); + assertEquals("8",expression.getValue(context).toString()); + context.setVariable("var", null); + assertNull(expression.getValue(context)); + } + @Test public void elvis() throws Exception { Expression expression = parser.parseExpression("'a'?:'b'"); @@ -3065,19 +3225,47 @@ public class SpelCompilationCoverageTests extends AbstractExpressionTests { assertEquals(1.0f, expression.getValue()); } + @Test + public void compilationOfBasicNullSafeMethodReference() { + SpelExpressionParser parser = new SpelExpressionParser( + new SpelParserConfiguration(SpelCompilerMode.OFF, getClass().getClassLoader())); + SpelExpression expression = parser.parseRaw("#it?.equals(3)"); + StandardEvaluationContext context = new StandardEvaluationContext(new Object[] {1}); + context.setVariable("it", 3); + expression.setEvaluationContext(context); + assertTrue(expression.getValue(Boolean.class)); + context.setVariable("it", null); + assertNull(expression.getValue(Boolean.class)); + + assertCanCompile(expression); + + context.setVariable("it", 3); + assertTrue(expression.getValue(Boolean.class)); + context.setVariable("it", null); + assertNull(expression.getValue(Boolean.class)); + } + @Test public void failsWhenSettingContextForExpression_SPR12326() { SpelExpressionParser parser = new SpelExpressionParser( - new SpelParserConfiguration(SpelCompilerMode.IMMEDIATE, getClass().getClassLoader())); + new SpelParserConfiguration(SpelCompilerMode.OFF, getClass().getClassLoader())); Person3 person = new Person3("foo", 1); SpelExpression expression = parser.parseRaw("#it?.age?.equals([0])"); StandardEvaluationContext context = new StandardEvaluationContext(new Object[] {1}); context.setVariable("it", person); expression.setEvaluationContext(context); assertTrue(expression.getValue(Boolean.class)); + // This will trigger compilation (second usage) assertTrue(expression.getValue(Boolean.class)); + context.setVariable("it", null); + assertNull(expression.getValue(Boolean.class)); + assertCanCompile(expression); + + context.setVariable("it", person); assertTrue(expression.getValue(Boolean.class)); + context.setVariable("it", null); + assertNull(expression.getValue(Boolean.class)); } @@ -5100,6 +5288,14 @@ public class SpelCompilationCoverageTests extends AbstractExpressionTests { } } + public static class FooObjectHolder { + + private FooObject foo = new FooObject(); + + public FooObject getFoo() { + return foo; + } + } public static class FooObject {