This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch GROOVY-6954
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit fd171dacd3a7ef3c0dbfc3da456cb8639a3a942e
Author: Eric Milles <[email protected]>
AuthorDate: Thu Oct 21 17:14:14 2021 -0500

    GROOVY-6954: SC: optimize `map.foo = 'bar'` using java.util.Map#put(K,V)
---
 .../classgen/asm/sc/StaticTypesCallSiteWriter.java | 124 +++++++++++++--------
 src/test/groovy/bugs/Groovy6954.groovy             |  78 +++++++++++++
 .../classgen/asm/AbstractBytecodeTestCase.groovy   |  11 +-
 3 files changed, 163 insertions(+), 50 deletions(-)

diff --git 
a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesCallSiteWriter.java
 
b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesCallSiteWriter.java
index 3958dee..016899f 100644
--- 
a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesCallSiteWriter.java
+++ 
b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesCallSiteWriter.java
@@ -108,6 +108,7 @@ import static org.objectweb.asm.Opcodes.GETSTATIC;
 import static org.objectweb.asm.Opcodes.GOTO;
 import static org.objectweb.asm.Opcodes.IFEQ;
 import static org.objectweb.asm.Opcodes.IFNONNULL;
+import static org.objectweb.asm.Opcodes.IFNULL;
 import static org.objectweb.asm.Opcodes.INVOKEINTERFACE;
 import static org.objectweb.asm.Opcodes.INVOKESPECIAL;
 import static org.objectweb.asm.Opcodes.INVOKESTATIC;
@@ -261,43 +262,40 @@ public class StaticTypesCallSiteWriter extends 
CallSiteWriter {
         );
         call.setImplicitThis(false);
         call.setMethodTarget(target);
-        call.setSafe(false);
         call.visit(controller.getAcg());
     }
 
     private void writeMapDotProperty(final Expression receiver, final String 
propertyName, final boolean safe) {
-        receiver.visit(controller.getAcg()); // load receiver
-
         MethodVisitor mv = controller.getMethodVisitor();
 
-        Label exit = new Label();
+        // push receiver on stack
+        receiver.visit(controller.getAcg());
+
+        // check if receiver null
+        Label skip = new Label();
         if (safe) {
-            Label doGet = new Label();
-            mv.visitJumpInsn(IFNONNULL, doGet);
-            controller.getOperandStack().remove(1);
-            mv.visitInsn(ACONST_NULL);
-            mv.visitJumpInsn(GOTO, exit);
-            mv.visitLabel(doGet);
-            receiver.visit(controller.getAcg());
+            mv.visitInsn(DUP);
+            mv.visitJumpInsn(IFNULL, skip);
         }
 
-        mv.visitLdcInsn(propertyName); // load property name
+        mv.visitLdcInsn(propertyName);
         mv.visitMethodInsn(INVOKEINTERFACE, "java/util/Map", "get", 
"(Ljava/lang/Object;)Ljava/lang/Object;", true);
+
         if (safe) {
-            mv.visitLabel(exit);
+            mv.visitLabel(skip);
         }
         controller.getOperandStack().replace(OBJECT_TYPE);
     }
 
     private void writeListDotProperty(final Expression receiver, final String 
propertyName, final boolean safe) {
+        // for lists, replace list.foo with:
+        //   def result = new ArrayList(list.size())
+        //   for (item in list) result.add(item.foo)
+        //   result
         ClassNode componentType = 
receiver.getNodeMetaData(StaticCompilationMetadataKeys.COMPONENT_TYPE);
         if (componentType == null) {
             componentType = OBJECT_TYPE;
         }
-        // for lists, replace list.foo with:
-        // def result = new ArrayList(list.size())
-        // for (e in list) { result.add (e.foo) }
-        // result
         CompileStack compileStack = controller.getCompileStack();
         MethodVisitor mv = controller.getMethodVisitor();
 
@@ -527,44 +525,44 @@ public class StaticTypesCallSiteWriter extends 
CallSiteWriter {
         FieldNode field = getField(receiverType, fieldName); // GROOVY-7039: 
include interface constants
         if (field != null && 
AsmClassGenerator.isFieldDirectlyAccessible(field, controller.getClassNode())) {
             CompileStack compileStack = controller.getCompileStack();
-            MethodVisitor mv = controller.getMethodVisitor();
-            ClassNode replacementType = field.getOriginType();
             OperandStack operandStack = controller.getOperandStack();
+            MethodVisitor mv = controller.getMethodVisitor();
+            ClassNode resultType = field.getOriginType();
             if (field.isStatic()) {
-                mv.visitFieldInsn(GETSTATIC, 
BytecodeHelper.getClassInternalName(receiverType), fieldName, 
BytecodeHelper.getTypeDescription(replacementType));
-                operandStack.push(replacementType);
+                mv.visitFieldInsn(GETSTATIC, 
BytecodeHelper.getClassInternalName(receiverType), fieldName, 
BytecodeHelper.getTypeDescription(resultType));
+                operandStack.push(resultType);
             } else {
                 if (implicitThis) {
-                    compileStack.pushImplicitThis(implicitThis);
+                    compileStack.pushImplicitThis(true);
                     receiver.visit(controller.getAcg());
                     compileStack.popImplicitThis();
                 } else {
                     receiver.visit(controller.getAcg());
                 }
-                Label exit = new Label();
+                Label skip = new Label();
                 if (safe) {
                     mv.visitInsn(DUP);
                     Label doGet = new Label();
                     mv.visitJumpInsn(IFNONNULL, doGet);
                     mv.visitInsn(POP);
                     mv.visitInsn(ACONST_NULL);
-                    mv.visitJumpInsn(GOTO, exit);
+                    mv.visitJumpInsn(GOTO, skip);
                     mv.visitLabel(doGet);
                 }
                 if 
(!operandStack.getTopOperand().isDerivedFrom(field.getOwner())) {
                     mv.visitTypeInsn(CHECKCAST, 
BytecodeHelper.getClassInternalName(field.getOwner()));
                 }
-                mv.visitFieldInsn(GETFIELD, 
BytecodeHelper.getClassInternalName(field.getOwner()), fieldName, 
BytecodeHelper.getTypeDescription(replacementType));
+                mv.visitFieldInsn(GETFIELD, 
BytecodeHelper.getClassInternalName(field.getOwner()), fieldName, 
BytecodeHelper.getTypeDescription(resultType));
                 if (safe) {
-                    if (ClassHelper.isPrimitiveType(replacementType)) {
-                        operandStack.replace(replacementType);
+                    if (isPrimitiveType(resultType)) {
+                        operandStack.replace(resultType);
                         operandStack.box();
-                        replacementType = operandStack.getTopOperand();
+                        resultType = operandStack.getTopOperand();
                     }
-                    mv.visitLabel(exit);
+                    mv.visitLabel(skip);
                 }
             }
-            operandStack.replace(replacementType);
+            operandStack.replace(resultType);
             return true;
         }
         return false;
@@ -753,7 +751,7 @@ public class StaticTypesCallSiteWriter extends 
CallSiteWriter {
         } else {
             mv.visitMethodInsn(INVOKESTATIC, 
"org/codehaus/groovy/runtime/DefaultGroovyMethods", "power", 
"(Ljava/lang/Number;Ljava/lang/Number;)Ljava/lang/Number;", false);
         }
-        controller.getOperandStack().replace(Number_TYPE, m2 - m1);
+        operandStack.replace(Number_TYPE, m2 - m1);
     }
 
     private void writeStringPlusCall(final Expression receiver, final String 
message, final Expression arguments) {
@@ -766,7 +764,7 @@ public class StaticTypesCallSiteWriter extends 
CallSiteWriter {
         int m2 = operandStack.getStackLength();
         MethodVisitor mv = controller.getMethodVisitor();
         mv.visitMethodInsn(INVOKESTATIC, 
"org/codehaus/groovy/runtime/DefaultGroovyMethods", "plus", 
"(Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/String;", false);
-        controller.getOperandStack().replace(STRING_TYPE, m2 - m1);
+        operandStack.replace(STRING_TYPE, m2 - m1);
     }
 
     private void writeNumberNumberCall(final Expression receiver, final String 
message, final Expression arguments) {
@@ -774,22 +772,27 @@ public class StaticTypesCallSiteWriter extends 
CallSiteWriter {
         int m1 = operandStack.getStackLength();
         // slow path
         prepareSiteAndReceiver(receiver, message, false, 
controller.getCompileStack().isLHS());
-        controller.getOperandStack().doGroovyCast(Number_TYPE);
+        operandStack.doGroovyCast(Number_TYPE);
         visitBoxedArgument(arguments);
-        controller.getOperandStack().doGroovyCast(Number_TYPE);
+        operandStack.doGroovyCast(Number_TYPE);
         int m2 = operandStack.getStackLength();
         MethodVisitor mv = controller.getMethodVisitor();
         mv.visitMethodInsn(INVOKESTATIC, 
"org/codehaus/groovy/runtime/dgmimpl/NumberNumber" + capitalize(message), 
message, "(Ljava/lang/Number;Ljava/lang/Number;)Ljava/lang/Number;", false);
-        controller.getOperandStack().replace(Number_TYPE, m2 - m1);
+        operandStack.replace(Number_TYPE, m2 - m1);
     }
 
     @Override
     public void fallbackAttributeOrPropertySite(final PropertyExpression 
expression, final Expression objectExpression, final String name, final 
MethodCallerMultiAdapter adapter) {
-        if (name != null && controller.getCompileStack().isLHS()) {
-            ClassNode receiverType = getPropertyOwnerType(objectExpression);
+        CompileStack compileStack = controller.getCompileStack();
+        OperandStack operandStack = controller.getOperandStack();
+
+        if (name != null && compileStack.isLHS()) {
+            boolean[] isClassReceiver = new boolean[1];
+            ClassNode receiverType = getPropertyOwnerType(objectExpression, 
isClassReceiver);
             if (adapter == AsmClassGenerator.setField || adapter == 
AsmClassGenerator.setGroovyObjectField) {
                 if (setField(expression, objectExpression, receiverType, 
name)) return;
-            } else if (isThisExpression(objectExpression)) {
+            }
+            if (isThisExpression(objectExpression)) {
                 ClassNode classNode = controller.getClassNode();
                 FieldNode fieldNode = receiverType.getField(name);
                 if (fieldNode != null && fieldNode.isPrivate() && 
!receiverType.equals(classNode)
@@ -798,9 +801,9 @@ public class StaticTypesCallSiteWriter extends 
CallSiteWriter {
                     if (mutators != null) {
                         MethodNode methodNode = mutators.get(name);
                         if (methodNode != null) {
-                            ClassNode rhsType = 
controller.getOperandStack().getTopOperand();
-                            int i = 
controller.getCompileStack().defineTemporaryVariable("$rhsValue", rhsType, 
true);
-                            VariableSlotLoader rhsValue = new 
VariableSlotLoader(rhsType, i, controller.getOperandStack());
+                            ClassNode rhsType = operandStack.getTopOperand();
+                            int i = 
compileStack.defineTemporaryVariable("$rhs", rhsType, true);
+                            VariableSlotLoader rhsValue = new 
VariableSlotLoader(rhsType, i, operandStack);
 
                             MethodCallExpression call = 
callX(objectExpression, methodNode.getName(), args(fieldNode.isStatic() ? 
nullX() : objectExpression, rhsValue));
                             call.setImplicitThis(expression.isImplicitThis());
@@ -810,16 +813,49 @@ public class StaticTypesCallSiteWriter extends 
CallSiteWriter {
                             call.visit(controller.getAcg());
 
                             // GROOVY-9892: assuming that the mutator method 
has a return value, make sure the operand
-                            //  stack is not polluted with the result of the 
method call
-                            controller.getOperandStack().pop();
+                            // stack is not polluted with the result of the 
method call
+                            operandStack.pop();
 
-                            controller.getCompileStack().removeVar(i);
+                            compileStack.removeVar(i);
                             return;
                         }
                     }
                 }
             }
+            if (isOrImplements(receiverType, MAP_TYPE) && !isClassReceiver[0]) 
{
+                MethodVisitor mv = controller.getMethodVisitor();
+
+                // store value in temporary variable
+                ClassNode rhsType = operandStack.getTopOperand();
+                int rhs = compileStack.defineTemporaryVariable("$rhs", 
rhsType, true);
+
+                // push receiver on stack
+                compileStack.pushLHS(false);
+                objectExpression.visit(controller.getAcg());
+                compileStack.popLHS();
+
+                // check if receiver null
+                Label skip = new Label();
+                if (expression.isSafe()) {
+                    mv.visitInsn(DUP);
+                    mv.visitJumpInsn(IFNULL, skip);
+                }
+
+                mv.visitLdcInsn(name);
+                BytecodeHelper.load(mv, rhsType, rhs);
+                if (isPrimitiveType(rhsType)) 
BytecodeHelper.doCastToWrappedType(mv, rhsType, getWrapper(rhsType));
+                mv.visitMethodInsn(INVOKEINTERFACE, "java/util/Map", "put", 
"(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", true);
+
+                if (expression.isSafe()) {
+                    mv.visitLabel(skip);
+                }
+                // no return value
+                operandStack.pop();
+                compileStack.removeVar(rhs);
+                return;
+            }
         }
+
         super.fallbackAttributeOrPropertySite(expression, objectExpression, 
name, adapter);
     }
 
diff --git a/src/test/groovy/bugs/Groovy6954.groovy 
b/src/test/groovy/bugs/Groovy6954.groovy
new file mode 100644
index 0000000..9e9830f
--- /dev/null
+++ b/src/test/groovy/bugs/Groovy6954.groovy
@@ -0,0 +1,78 @@
+/*
+ *  Licensed to the Apache Software Foundation (ASF) under one
+ *  or more contributor license agreements.  See the NOTICE file
+ *  distributed with this work for additional information
+ *  regarding copyright ownership.  The ASF licenses this file
+ *  to you under the Apache License, Version 2.0 (the
+ *  "License"); you may not use this file except in compliance
+ *  with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing,
+ *  software distributed under the License is distributed on an
+ *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ *  KIND, either express or implied.  See the License for the
+ *  specific language governing permissions and limitations
+ *  under the License.
+ */
+package groovy.bugs
+
+import org.codehaus.groovy.classgen.asm.AbstractBytecodeTestCase
+
+final class Groovy6954 extends AbstractBytecodeTestCase {
+
+    void testSetMapDotProperty() {
+        extractionOptions.method = 'put'
+
+        assertScript '''
+            @groovy.transform.CompileStatic
+            def put(Map<String, ?> map) {
+                if (false) map.boo = -1
+                map.foo = 'bar'
+            }
+            def map = [:]
+            assert put(map) == 'bar'
+            assert map.foo  == 'bar'
+            assert !map.containsKey('boo')
+        '''
+
+        assert sequence.hasSequence([
+            'INVOKESTATIC java/lang/Integer.valueOf (I)Ljava/lang/Integer;', 
// boxing -1
+            'INVOKEINTERFACE java/util/Map.put ' // not 
ScriptBytecodeAdapter.setProperty
+        ], sequence.indexOf('--BEGIN--'))
+    }
+
+    void testSafeSetMapDotProperty() {
+        extractionOptions.method = 'put'
+
+        assertScript '''
+            @groovy.transform.CompileStatic
+            def put(Map<String, ?> map) {
+                map?.foo = 'bar'
+            }
+            assert put(null) == 'bar'
+        '''
+
+        assert sequence.hasStrictSequence([
+            'IFNULL L1',
+            'LDC "foo"',
+            'ALOAD 3',
+            'INVOKEINTERFACE java/util/Map.put ',
+            'L1'
+        ])
+    }
+
+    void testChainSetMapDotProperty() {
+        assertScript '''
+            @groovy.transform.CompileStatic
+            def put(Map<String, ?> map) {
+                map.foo = map.bar = 'baz'
+            }
+            def map = [:]
+            assert put(map) == 'baz'
+            assert map.foo  == 'baz'
+            assert map.bar  == 'baz'
+        '''
+    }
+}
diff --git 
a/src/test/org/codehaus/groovy/classgen/asm/AbstractBytecodeTestCase.groovy 
b/src/test/org/codehaus/groovy/classgen/asm/AbstractBytecodeTestCase.groovy
index a2e0850..6c8f9c5 100644
--- a/src/test/org/codehaus/groovy/classgen/asm/AbstractBytecodeTestCase.groovy
+++ b/src/test/org/codehaus/groovy/classgen/asm/AbstractBytecodeTestCase.groovy
@@ -48,21 +48,20 @@ abstract class AbstractBytecodeTestCase extends 
GroovyTestCase {
 
     @Override
     protected void assertScript(final String script) throws Exception {
-        GroovyShell shell = new GroovyShell()
-        def unit
-        shell.loader = new GroovyClassLoader() {
+        CompilationUnit unit = null
+        GroovyShell shell = new GroovyShell(new GroovyClassLoader() {
             @Override
             protected CompilationUnit createCompilationUnit(final 
CompilerConfiguration config, final CodeSource source) {
                 unit = super.createCompilationUnit(config, source)
             }
-        }
+        })
         try {
             shell.evaluate(script, testClassName)
         } finally {
-            if (unit) {
+            if (unit != null) {
                 try {
                     sequence = extractSequence(unit.classes[0].bytes, 
extractionOptions)
-                    if (extractionOptions.print) println sequence
+                    if (extractionOptions.print) println(sequence)
                 } catch (e) {
                     // probably an error in the script
                 }

Reply via email to