This is an automated email from the ASF dual-hosted git repository. emilles pushed a commit to branch GROOVY-6954 in repository https://gitbox.apache.org/repos/asf/groovy.git
commit fd171dacd3a7ef3c0dbfc3da456cb8639a3a942e Author: Eric Milles <[email protected]> AuthorDate: Thu Oct 21 17:14:14 2021 -0500 GROOVY-6954: SC: optimize `map.foo = 'bar'` using java.util.Map#put(K,V) --- .../classgen/asm/sc/StaticTypesCallSiteWriter.java | 124 +++++++++++++-------- src/test/groovy/bugs/Groovy6954.groovy | 78 +++++++++++++ .../classgen/asm/AbstractBytecodeTestCase.groovy | 11 +- 3 files changed, 163 insertions(+), 50 deletions(-) diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesCallSiteWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesCallSiteWriter.java index 3958dee..016899f 100644 --- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesCallSiteWriter.java +++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesCallSiteWriter.java @@ -108,6 +108,7 @@ import static org.objectweb.asm.Opcodes.GETSTATIC; import static org.objectweb.asm.Opcodes.GOTO; import static org.objectweb.asm.Opcodes.IFEQ; import static org.objectweb.asm.Opcodes.IFNONNULL; +import static org.objectweb.asm.Opcodes.IFNULL; import static org.objectweb.asm.Opcodes.INVOKEINTERFACE; import static org.objectweb.asm.Opcodes.INVOKESPECIAL; import static org.objectweb.asm.Opcodes.INVOKESTATIC; @@ -261,43 +262,40 @@ public class StaticTypesCallSiteWriter extends CallSiteWriter { ); call.setImplicitThis(false); call.setMethodTarget(target); - call.setSafe(false); call.visit(controller.getAcg()); } private void writeMapDotProperty(final Expression receiver, final String propertyName, final boolean safe) { - receiver.visit(controller.getAcg()); // load receiver - MethodVisitor mv = controller.getMethodVisitor(); - Label exit = new Label(); + // push receiver on stack + receiver.visit(controller.getAcg()); + + // check if receiver null + Label skip = new Label(); if (safe) { - Label doGet = new Label(); - mv.visitJumpInsn(IFNONNULL, doGet); - controller.getOperandStack().remove(1); - mv.visitInsn(ACONST_NULL); - mv.visitJumpInsn(GOTO, exit); - mv.visitLabel(doGet); - receiver.visit(controller.getAcg()); + mv.visitInsn(DUP); + mv.visitJumpInsn(IFNULL, skip); } - mv.visitLdcInsn(propertyName); // load property name + mv.visitLdcInsn(propertyName); mv.visitMethodInsn(INVOKEINTERFACE, "java/util/Map", "get", "(Ljava/lang/Object;)Ljava/lang/Object;", true); + if (safe) { - mv.visitLabel(exit); + mv.visitLabel(skip); } controller.getOperandStack().replace(OBJECT_TYPE); } private void writeListDotProperty(final Expression receiver, final String propertyName, final boolean safe) { + // for lists, replace list.foo with: + // def result = new ArrayList(list.size()) + // for (item in list) result.add(item.foo) + // result ClassNode componentType = receiver.getNodeMetaData(StaticCompilationMetadataKeys.COMPONENT_TYPE); if (componentType == null) { componentType = OBJECT_TYPE; } - // for lists, replace list.foo with: - // def result = new ArrayList(list.size()) - // for (e in list) { result.add (e.foo) } - // result CompileStack compileStack = controller.getCompileStack(); MethodVisitor mv = controller.getMethodVisitor(); @@ -527,44 +525,44 @@ public class StaticTypesCallSiteWriter extends CallSiteWriter { FieldNode field = getField(receiverType, fieldName); // GROOVY-7039: include interface constants if (field != null && AsmClassGenerator.isFieldDirectlyAccessible(field, controller.getClassNode())) { CompileStack compileStack = controller.getCompileStack(); - MethodVisitor mv = controller.getMethodVisitor(); - ClassNode replacementType = field.getOriginType(); OperandStack operandStack = controller.getOperandStack(); + MethodVisitor mv = controller.getMethodVisitor(); + ClassNode resultType = field.getOriginType(); if (field.isStatic()) { - mv.visitFieldInsn(GETSTATIC, BytecodeHelper.getClassInternalName(receiverType), fieldName, BytecodeHelper.getTypeDescription(replacementType)); - operandStack.push(replacementType); + mv.visitFieldInsn(GETSTATIC, BytecodeHelper.getClassInternalName(receiverType), fieldName, BytecodeHelper.getTypeDescription(resultType)); + operandStack.push(resultType); } else { if (implicitThis) { - compileStack.pushImplicitThis(implicitThis); + compileStack.pushImplicitThis(true); receiver.visit(controller.getAcg()); compileStack.popImplicitThis(); } else { receiver.visit(controller.getAcg()); } - Label exit = new Label(); + Label skip = new Label(); if (safe) { mv.visitInsn(DUP); Label doGet = new Label(); mv.visitJumpInsn(IFNONNULL, doGet); mv.visitInsn(POP); mv.visitInsn(ACONST_NULL); - mv.visitJumpInsn(GOTO, exit); + mv.visitJumpInsn(GOTO, skip); mv.visitLabel(doGet); } if (!operandStack.getTopOperand().isDerivedFrom(field.getOwner())) { mv.visitTypeInsn(CHECKCAST, BytecodeHelper.getClassInternalName(field.getOwner())); } - mv.visitFieldInsn(GETFIELD, BytecodeHelper.getClassInternalName(field.getOwner()), fieldName, BytecodeHelper.getTypeDescription(replacementType)); + mv.visitFieldInsn(GETFIELD, BytecodeHelper.getClassInternalName(field.getOwner()), fieldName, BytecodeHelper.getTypeDescription(resultType)); if (safe) { - if (ClassHelper.isPrimitiveType(replacementType)) { - operandStack.replace(replacementType); + if (isPrimitiveType(resultType)) { + operandStack.replace(resultType); operandStack.box(); - replacementType = operandStack.getTopOperand(); + resultType = operandStack.getTopOperand(); } - mv.visitLabel(exit); + mv.visitLabel(skip); } } - operandStack.replace(replacementType); + operandStack.replace(resultType); return true; } return false; @@ -753,7 +751,7 @@ public class StaticTypesCallSiteWriter extends CallSiteWriter { } else { mv.visitMethodInsn(INVOKESTATIC, "org/codehaus/groovy/runtime/DefaultGroovyMethods", "power", "(Ljava/lang/Number;Ljava/lang/Number;)Ljava/lang/Number;", false); } - controller.getOperandStack().replace(Number_TYPE, m2 - m1); + operandStack.replace(Number_TYPE, m2 - m1); } private void writeStringPlusCall(final Expression receiver, final String message, final Expression arguments) { @@ -766,7 +764,7 @@ public class StaticTypesCallSiteWriter extends CallSiteWriter { int m2 = operandStack.getStackLength(); MethodVisitor mv = controller.getMethodVisitor(); mv.visitMethodInsn(INVOKESTATIC, "org/codehaus/groovy/runtime/DefaultGroovyMethods", "plus", "(Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/String;", false); - controller.getOperandStack().replace(STRING_TYPE, m2 - m1); + operandStack.replace(STRING_TYPE, m2 - m1); } private void writeNumberNumberCall(final Expression receiver, final String message, final Expression arguments) { @@ -774,22 +772,27 @@ public class StaticTypesCallSiteWriter extends CallSiteWriter { int m1 = operandStack.getStackLength(); // slow path prepareSiteAndReceiver(receiver, message, false, controller.getCompileStack().isLHS()); - controller.getOperandStack().doGroovyCast(Number_TYPE); + operandStack.doGroovyCast(Number_TYPE); visitBoxedArgument(arguments); - controller.getOperandStack().doGroovyCast(Number_TYPE); + operandStack.doGroovyCast(Number_TYPE); int m2 = operandStack.getStackLength(); MethodVisitor mv = controller.getMethodVisitor(); mv.visitMethodInsn(INVOKESTATIC, "org/codehaus/groovy/runtime/dgmimpl/NumberNumber" + capitalize(message), message, "(Ljava/lang/Number;Ljava/lang/Number;)Ljava/lang/Number;", false); - controller.getOperandStack().replace(Number_TYPE, m2 - m1); + operandStack.replace(Number_TYPE, m2 - m1); } @Override public void fallbackAttributeOrPropertySite(final PropertyExpression expression, final Expression objectExpression, final String name, final MethodCallerMultiAdapter adapter) { - if (name != null && controller.getCompileStack().isLHS()) { - ClassNode receiverType = getPropertyOwnerType(objectExpression); + CompileStack compileStack = controller.getCompileStack(); + OperandStack operandStack = controller.getOperandStack(); + + if (name != null && compileStack.isLHS()) { + boolean[] isClassReceiver = new boolean[1]; + ClassNode receiverType = getPropertyOwnerType(objectExpression, isClassReceiver); if (adapter == AsmClassGenerator.setField || adapter == AsmClassGenerator.setGroovyObjectField) { if (setField(expression, objectExpression, receiverType, name)) return; - } else if (isThisExpression(objectExpression)) { + } + if (isThisExpression(objectExpression)) { ClassNode classNode = controller.getClassNode(); FieldNode fieldNode = receiverType.getField(name); if (fieldNode != null && fieldNode.isPrivate() && !receiverType.equals(classNode) @@ -798,9 +801,9 @@ public class StaticTypesCallSiteWriter extends CallSiteWriter { if (mutators != null) { MethodNode methodNode = mutators.get(name); if (methodNode != null) { - ClassNode rhsType = controller.getOperandStack().getTopOperand(); - int i = controller.getCompileStack().defineTemporaryVariable("$rhsValue", rhsType, true); - VariableSlotLoader rhsValue = new VariableSlotLoader(rhsType, i, controller.getOperandStack()); + ClassNode rhsType = operandStack.getTopOperand(); + int i = compileStack.defineTemporaryVariable("$rhs", rhsType, true); + VariableSlotLoader rhsValue = new VariableSlotLoader(rhsType, i, operandStack); MethodCallExpression call = callX(objectExpression, methodNode.getName(), args(fieldNode.isStatic() ? nullX() : objectExpression, rhsValue)); call.setImplicitThis(expression.isImplicitThis()); @@ -810,16 +813,49 @@ public class StaticTypesCallSiteWriter extends CallSiteWriter { call.visit(controller.getAcg()); // GROOVY-9892: assuming that the mutator method has a return value, make sure the operand - // stack is not polluted with the result of the method call - controller.getOperandStack().pop(); + // stack is not polluted with the result of the method call + operandStack.pop(); - controller.getCompileStack().removeVar(i); + compileStack.removeVar(i); return; } } } } + if (isOrImplements(receiverType, MAP_TYPE) && !isClassReceiver[0]) { + MethodVisitor mv = controller.getMethodVisitor(); + + // store value in temporary variable + ClassNode rhsType = operandStack.getTopOperand(); + int rhs = compileStack.defineTemporaryVariable("$rhs", rhsType, true); + + // push receiver on stack + compileStack.pushLHS(false); + objectExpression.visit(controller.getAcg()); + compileStack.popLHS(); + + // check if receiver null + Label skip = new Label(); + if (expression.isSafe()) { + mv.visitInsn(DUP); + mv.visitJumpInsn(IFNULL, skip); + } + + mv.visitLdcInsn(name); + BytecodeHelper.load(mv, rhsType, rhs); + if (isPrimitiveType(rhsType)) BytecodeHelper.doCastToWrappedType(mv, rhsType, getWrapper(rhsType)); + mv.visitMethodInsn(INVOKEINTERFACE, "java/util/Map", "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", true); + + if (expression.isSafe()) { + mv.visitLabel(skip); + } + // no return value + operandStack.pop(); + compileStack.removeVar(rhs); + return; + } } + super.fallbackAttributeOrPropertySite(expression, objectExpression, name, adapter); } diff --git a/src/test/groovy/bugs/Groovy6954.groovy b/src/test/groovy/bugs/Groovy6954.groovy new file mode 100644 index 0000000..9e9830f --- /dev/null +++ b/src/test/groovy/bugs/Groovy6954.groovy @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package groovy.bugs + +import org.codehaus.groovy.classgen.asm.AbstractBytecodeTestCase + +final class Groovy6954 extends AbstractBytecodeTestCase { + + void testSetMapDotProperty() { + extractionOptions.method = 'put' + + assertScript ''' + @groovy.transform.CompileStatic + def put(Map<String, ?> map) { + if (false) map.boo = -1 + map.foo = 'bar' + } + def map = [:] + assert put(map) == 'bar' + assert map.foo == 'bar' + assert !map.containsKey('boo') + ''' + + assert sequence.hasSequence([ + 'INVOKESTATIC java/lang/Integer.valueOf (I)Ljava/lang/Integer;', // boxing -1 + 'INVOKEINTERFACE java/util/Map.put ' // not ScriptBytecodeAdapter.setProperty + ], sequence.indexOf('--BEGIN--')) + } + + void testSafeSetMapDotProperty() { + extractionOptions.method = 'put' + + assertScript ''' + @groovy.transform.CompileStatic + def put(Map<String, ?> map) { + map?.foo = 'bar' + } + assert put(null) == 'bar' + ''' + + assert sequence.hasStrictSequence([ + 'IFNULL L1', + 'LDC "foo"', + 'ALOAD 3', + 'INVOKEINTERFACE java/util/Map.put ', + 'L1' + ]) + } + + void testChainSetMapDotProperty() { + assertScript ''' + @groovy.transform.CompileStatic + def put(Map<String, ?> map) { + map.foo = map.bar = 'baz' + } + def map = [:] + assert put(map) == 'baz' + assert map.foo == 'baz' + assert map.bar == 'baz' + ''' + } +} diff --git a/src/test/org/codehaus/groovy/classgen/asm/AbstractBytecodeTestCase.groovy b/src/test/org/codehaus/groovy/classgen/asm/AbstractBytecodeTestCase.groovy index a2e0850..6c8f9c5 100644 --- a/src/test/org/codehaus/groovy/classgen/asm/AbstractBytecodeTestCase.groovy +++ b/src/test/org/codehaus/groovy/classgen/asm/AbstractBytecodeTestCase.groovy @@ -48,21 +48,20 @@ abstract class AbstractBytecodeTestCase extends GroovyTestCase { @Override protected void assertScript(final String script) throws Exception { - GroovyShell shell = new GroovyShell() - def unit - shell.loader = new GroovyClassLoader() { + CompilationUnit unit = null + GroovyShell shell = new GroovyShell(new GroovyClassLoader() { @Override protected CompilationUnit createCompilationUnit(final CompilerConfiguration config, final CodeSource source) { unit = super.createCompilationUnit(config, source) } - } + }) try { shell.evaluate(script, testClassName) } finally { - if (unit) { + if (unit != null) { try { sequence = extractSequence(unit.classes[0].bytes, extractionOptions) - if (extractionOptions.print) println sequence + if (extractionOptions.print) println(sequence) } catch (e) { // probably an error in the script }
