This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new d50a84f3df GROOVY-11591: short-circuit safe method call for case of `null` receiver
d50a84f3df is described below

commit d50a84f3df92559850834abac9504c8265eaa1d4
Author: Eric Milles <[email protected]>
AuthorDate: Sat Mar 29 16:25:01 2025 -0500

    GROOVY-11591: short-circuit safe method call for case of `null` receiver
---
 .../classgen/asm/indy/InvokeDynamicWriter.java     |  13 ++-
 src/spec/test/OperatorsTest.groovy                 | 109 ++++++++++++++-------
 src/test-resources/core/SafeChainOperator.groovy   |  21 ++--
 3 files changed, 99 insertions(+), 44 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/indy/InvokeDynamicWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/indy/InvokeDynamicWriter.java
index 53d46d2c85..4316ddc5a7 100644
--- a/src/main/java/org/codehaus/groovy/classgen/asm/indy/InvokeDynamicWriter.java
+++ b/src/main/java/org/codehaus/groovy/classgen/asm/indy/InvokeDynamicWriter.java
@@ -36,6 +36,7 @@ import org.codehaus.groovy.classgen.asm.WriterController;
 import org.codehaus.groovy.runtime.wrappers.Wrapper;
 import org.codehaus.groovy.vmplugin.v8.IndyInterface;
 import org.objectweb.asm.Handle;
+import org.objectweb.asm.Label;
 import org.objectweb.asm.Opcodes;
 
 import java.lang.invoke.CallSite;
@@ -48,6 +49,7 @@ import static org.codehaus.groovy.ast.ClassHelper.OBJECT_TYPE;
 import static org.codehaus.groovy.ast.ClassHelper.boolean_TYPE;
 import static org.codehaus.groovy.ast.ClassHelper.getWrapper;
 import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveBoolean;
+import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveType;
 import static org.codehaus.groovy.ast.ClassHelper.isWrapperBoolean;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.bytecodeX;
 import static org.codehaus.groovy.classgen.asm.BytecodeHelper.doCast;
@@ -63,6 +65,7 @@ import static org.codehaus.groovy.vmplugin.v8.IndyInterface.CallType.INIT;
 import static org.codehaus.groovy.vmplugin.v8.IndyInterface.CallType.INTERFACE;
 import static org.codehaus.groovy.vmplugin.v8.IndyInterface.CallType.METHOD;
 import static org.objectweb.asm.Opcodes.H_INVOKESTATIC;
+import static org.objectweb.asm.Opcodes.IFNULL;
 
 /**
  * This Writer is used to generate the call invocation byte codes
@@ -132,6 +135,12 @@ public class InvokeDynamicWriter extends InvocationWriter {
         Expression receiver = correctReceiverForInterfaceCall(origReceiver, operandStack);
         StringBuilder sig = new StringBuilder(prepareIndyCall(receiver, implicitThis));
 
+        Label end = null;
+        if (safe && !isPrimitiveType(operandStack.getTopOperand())) {
+            operandStack.dup();
+            end = operandStack.jump(IFNULL);
+        }
+
         // load arguments
         int numberOfArguments = 1;
         List<Expression> args = makeArgumentList(arguments).getExpressions();
@@ -161,9 +170,11 @@ public class InvokeDynamicWriter extends InvocationWriter {
         // receiver != origReceiver interface default method call
         if (receiver != origReceiver) callSiteName = INTERFACE.getCallSiteName();
 
-        int flags = getMethodCallFlags(adapter, safe, containsSpreadExpression);
+        int flags = getMethodCallFlags(adapter, false, containsSpreadExpression);
 
         finishIndyCall(BSM, callSiteName, sig.toString(), numberOfArguments, methodName, flags);
+
+        if (end != null) controller.getMethodVisitor().visitLabel(end);
     }
 
     private Expression correctReceiverForInterfaceCall(Expression exp, OperandStack operandStack) {
diff --git a/src/spec/test/OperatorsTest.groovy b/src/spec/test/OperatorsTest.groovy
index c12185674d..09e3b2b889 100644
--- a/src/spec/test/OperatorsTest.groovy
+++ b/src/spec/test/OperatorsTest.groovy
@@ -16,13 +16,16 @@
  *  specific language governing permissions and limitations
  *  under the License.
  */
-import gls.CompilableTestSupport
+import org.junit.jupiter.api.Test
 
 import java.util.regex.Matcher
 import java.util.regex.Pattern
 
-class OperatorsTest extends CompilableTestSupport {
+import static groovy.test.GroovyAssert.assertScript
 
+final class OperatorsTest {
+
+    @Test
     void testArithmeticOperators() {
         // tag::binary_arith_ops[]
         assert  1  + 2 == 3
@@ -63,6 +66,7 @@ class OperatorsTest extends CompilableTestSupport {
         // end::plusplus_minusminus[]
     }
 
+    @Test
     void testArithmeticOperatorsWithAssignment() {
         // tag::binary_assign_operators[]
         def a = 4
@@ -97,6 +101,7 @@ class OperatorsTest extends CompilableTestSupport {
         // end::binary_assign_operators[]
     }
 
+    @Test
     void testSimpleRelationalOperators() {
         // tag::simple_relational_op[]
         assert 1 + 2 == 3
@@ -111,6 +116,7 @@ class OperatorsTest extends CompilableTestSupport {
         // end::simple_relational_op[]
     }
 
+    @Test
     void testLogicalOperators() {
         // tag::logical_op[]
         assert !false           // <1>
@@ -119,6 +125,7 @@ class OperatorsTest extends CompilableTestSupport {
         // end::logical_op[]
     }
 
+    @Test
     void testBitwiseOperators() {
         // tag::bitwise_op[]
         int a = 0b00101010
@@ -137,6 +144,7 @@ class OperatorsTest extends CompilableTestSupport {
         // end::bitwise_op[]
     }
 
+    @Test
     void testBitShiftOperators() {
         // tag::bit_shift_op[]
         assert 12.equals(3 << 2)           // <1>
@@ -150,6 +158,7 @@ class OperatorsTest extends CompilableTestSupport {
         // end::bit_shift_op[]
     }
 
+    @Test
     void testLogicalOperatorPrecedence() {
         // tag::logical_precendence_1[]
         assert (!false && false) == false   // <1>
@@ -160,32 +169,34 @@ class OperatorsTest extends CompilableTestSupport {
         // end::logical_precendence_2[]
     }
 
+    @Test
     void testLogicalShortCircuit() {
         assertScript '''
-               // tag::logical_shortcircuit[]
-               boolean checkIfCalled() {   // <1>
-                   called = true
-               }
-
-               called = false
-               true || checkIfCalled()
-               assert !called              // <2>
-
-               called = false
-               false || checkIfCalled()
-               assert called               // <3>
-
-               called = false
-               false && checkIfCalled()
-               assert !called              // <4>
-
-               called = false
-               true && checkIfCalled()
-               assert called               // <5>
-               // end::logical_shortcircuit[]
+            // tag::logical_shortcircuit[]
+            boolean checkIfCalled() {   // <1>
+                called = true
+            }
+
+            called = false
+            true || checkIfCalled()
+            assert !called              // <2>
+
+            called = false
+            false || checkIfCalled()
+            assert called               // <3>
+
+            called = false
+            false && checkIfCalled()
+            assert !called              // <4>
+
+            called = false
+            true && checkIfCalled()
+            assert called               // <5>
+            // end::logical_shortcircuit[]
         '''
     }
 
+    @Test
     void testConditionalOperators() {
         // tag::conditional_op_not[]
         assert (!true)    == false                      // <1>
@@ -219,20 +230,32 @@ class OperatorsTest extends CompilableTestSupport {
         displayName = user.name ? user.name : 'Anonymous'   // <1>
         displayName = user.name ?: 'Anonymous'              // <2>
         // end::conditional_op_elvis[]
+    }
 
+    private record Person(long id, String name) {
+        static Person find(Closure<?> c) { null }
     }
 
+    @Test
     void testNullSafeOperator() {
         // tag::nullsafe[]
         def person = Person.find { it.id == 123 }    // <1>
         def name = person?.name                      // <2>
         assert name == null                          // <3>
         // end::nullsafe[]
-    }
 
-    OperatorsTest() {
+        // GROOVY-11591
+        boolean called
+        def f = { -> called = true }
+        def obj = null
+        assert !called
+        obj?.grep(f())
+        assert !called
+        obj?[f()]
+        assert !called
     }
 
+    @Test
     void testDirectFieldAccess() {
         assertScript '''
 // tag::direct_field_class[]
@@ -250,6 +273,7 @@ assert user.@name == 'Bob'                   // <1>
 '''
     }
 
+    @Test
     void testMethodPointer() {
         // tag::method_pointer[]
         def str = 'example of method reference'            // <1>
@@ -310,6 +334,7 @@ assert user.@name == 'Bob'                   // <1>
         '''
     }
 
+    @Test
     void testMethodReference() {
         assertScript '''
             // tag::method_refs[]
@@ -344,6 +369,7 @@ assert user.@name == 'Bob'                   // <1>
         '''
     }
 
+    @Test
     void testRegularExpressionOperators() {
         def pattern = 'foo'
         // tag::pattern_op[]
@@ -396,6 +422,7 @@ assert user.@name == 'Bob'                   // <1>
         // end::pattern_find_vs_matcher[]
     }
 
+    @Test
     void testSpreadDotOperator() {
         assertScript '''
 // tag::spreaddot[]
@@ -418,6 +445,7 @@ assert cars*.make == ['Peugeot', null, 'Renault']     // <2>
 assert null*.make == null                             // <3>
 // end::spreaddot_nullsafe[]
 '''
+
         assertScript '''
 // tag::spreaddot_iterable[]
 class Component {
@@ -439,6 +467,7 @@ assert composite*.id == [1,2]
 assert composite*.name == ['Foo','Bar']
 // end::spreaddot_iterable[]
 '''
+
         assertScript '''
 import groovy.transform.Canonical
 
@@ -469,6 +498,7 @@ assert models.sum() == ['408', '508', 'Clio', 'Captur'] // flatten one level
 assert models.flatten() == ['408', '508', 'Clio', 'Captur'] // flatten all levels (one in this case)
 // end::spreaddot_multilevel[]
 '''
+
         assertScript '''
 // tag::spreaddot_alternative[]
 class Car {
@@ -490,6 +520,7 @@ assert models == [['408', '508'], ['Clio', 'Captur']]
 '''
     }
 
+    @Test
     void testSpreadMethodArguments() {
         assertScript '''
 // tag::spreadmethodargs_method[]
@@ -510,6 +541,7 @@ assert function(*args,5,6) == 26
 '''
     }
 
+    @Test
     void testSpreadList() {
         // tag::spread_list[]
         def items = [4,5]                      // <1>
@@ -518,6 +550,7 @@ assert function(*args,5,6) == 26
         // end::spread_list[]
     }
 
+    @Test
     void testSpreadMap() {
         assertScript '''
         // tag::spread_map[]
@@ -534,9 +567,9 @@ assert function(*args,5,6) == 26
         assert map == [a:1, b:2, c:3, d:8]    // <3>
         // end::spread_map_position[]
         '''
-
     }
 
+    @Test
     void testRangeOperator() {
         assertScript '''
         // tag::intrange[]
@@ -549,6 +582,7 @@ assert function(*args,5,6) == 26
         assert (0..5).size() == 6                           // <7>
         // end::intrange[]
         '''
+
         assertScript '''
         // tag::charrange[]
         assert ('a'..'d').collect() == ['a','b','c','d']
@@ -556,6 +590,7 @@ assert function(*args,5,6) == 26
         '''
     }
 
+    @Test
     void testSpaceshipOperator() {
         assertScript '''
         // tag::spaceship[]
@@ -564,9 +599,10 @@ assert function(*args,5,6) == 26
         assert (2 <=> 1) == 1
         assert ('a' <=> 'z') == -1
         // end::spaceship[]
-'''
+        '''
     }
 
+    @Test
     void testSubscriptOperator() {
         assertScript '''
         // tag::subscript_op[]
@@ -608,6 +644,7 @@ assert function(*args,5,6) == 26
         '''
     }
 
+    @Test
     void testMembershipOperator() {
         // tag::membership_op[]
         def list = ['Grace','Rob','Emmy']
@@ -616,6 +653,7 @@ assert function(*args,5,6) == 26
         // end::membership_op[]
     }
 
+    @Test
     void testIdentityOperator() {
         // tag::identity_op[]
         def list1 = ['Groovy 1.8','Groovy 2.0','Groovy 2.3']        // <1>
@@ -626,6 +664,7 @@ assert function(*args,5,6) == 26
         // end::identity_op[]
     }
 
+    @Test
     void testCoercionOperator() {
         try {
             // tag::coerce_op_cast[]
@@ -641,6 +680,7 @@ assert function(*args,5,6) == 26
             // end::coerce_op[]
             assert num == 42
         }
+
         assertScript '''
         // tag::coerce_op_custom[]
         class Identifiable {
@@ -664,12 +704,14 @@ assert function(*args,5,6) == 26
         '''
     }
 
+    @Test
     void testDiamondOperator() {
         // tag::diamond_op[]
         List<String> strings = new LinkedList<>()
         // end::diamond_op[]
     }
 
+    @Test
     void testCallOperator() {
         assertScript '''
         // tag::call_op[]
@@ -686,6 +728,7 @@ assert function(*args,5,6) == 26
         '''
     }
 
+    @Test
     void testOperatorOverloading() {
         assertScript '''
 // tag::operator_overload_class[]
@@ -706,6 +749,8 @@ assert (b1 + b2).size == 15                         // <1>
 // end::operator_overload_op[]
 '''
     }
+
+    @Test
     void testOperatorOverloadingWithDifferentArgumentType() {
         assertScript '''
 class Bucket {
@@ -726,12 +771,7 @@ assert (b1 + 11).size == 15
 '''
     }
 
-    private static class Person {
-        Long id
-        String name
-        static Person find(Closure c) { null }
-    }
-
+    @Test
     void testGStringEquals() {
         assertScript '''
             w = 'world'
@@ -746,6 +786,7 @@ assert (b1 + 11).size == 15
             '''
     }
 
+    @Test
     void testBooleanOr() {
         assertScript '''
 boolean trueValue1 = true, trueValue2 = true, trueValue3 = true
@@ -760,6 +801,7 @@ assert !(falseValue3 |= null)
 '''
     }
 
+    @Test
     void testBooleanAnd() {
         assertScript '''
 boolean trueValue1 = true, trueValue2 = true, trueValue3 = true
@@ -774,6 +816,7 @@ assert !(falseValue3 &= null)
 '''
     }
 
+    @Test
     void testBooleanXor() {
         assertScript '''
 boolean trueValue1 = true, trueValue2 = true, trueValue3 = true
diff --git a/src/test-resources/core/SafeChainOperator.groovy b/src/test-resources/core/SafeChainOperator.groovy
index e3784baa65..7fe9bab0d9 100644
--- a/src/test-resources/core/SafeChainOperator.groovy
+++ b/src/test-resources/core/SafeChainOperator.groovy
@@ -1,5 +1,3 @@
-import groovy.transform.CompileStatic
-
 /*
  *  Licensed to the Apache Software Foundation (ASF) under one
  *  or more contributor license agreements.  See the NOTICE file
@@ -18,11 +16,14 @@ import groovy.transform.CompileStatic
  *  specific language governing permissions and limitations
  *  under the License.
  */
+
+import groovy.transform.CompileStatic
+
 def testSCO() {
-    assert 3 == 1??.plus(2)
-    assert 6 == 1??.plus(2).plus(3)
-    assert 6 == 1??.plus(2)?.plus(3)
-    assert 6 == 1??.plus(2)??.plus(3)
+    assert  3 == 1??.plus(2)
+    assert  6 == 1??.plus(2).plus(3)
+    assert  6 == 1??.plus(2)?.plus(3)
+    assert  6 == 1??.plus(2)??.plus(3)
     assert 10 == 1??.plus(2)?.plus(3).plus(4)
     assert 10 == 1?.plus(2)??.plus(3).plus(4)
     assert 10 == 1?.plus(2)?.plus(3)??.plus(4)
@@ -41,10 +42,10 @@ testSCO()
 
 @CompileStatic
 def testCsSCO() {
-    assert 3 == 1??.plus(2)
-    assert 6 == 1??.plus(2).plus(3)
-    assert 6 == 1??.plus(2)?.plus(3)
-    assert 6 == 1??.plus(2)??.plus(3)
+    assert  3 == 1??.plus(2)
+    assert  6 == 1??.plus(2).plus(3)
+    assert  6 == 1??.plus(2)?.plus(3)
+    assert  6 == 1??.plus(2)??.plus(3)
     assert 10 == 1??.plus(2)?.plus(3).plus(4)
     assert 10 == 1?.plus(2)??.plus(3).plus(4)
     assert 10 == 1?.plus(2)?.plus(3)??.plus(4)

Reply via email to