This is an automated email from the ASF dual-hosted git repository. sunlan pushed a commit to branch GROOVY-11905 in repository https://gitbox.apache.org/repos/asf/groovy.git
commit 0e593b4bef4fbd77a3d77206a88b810ca93c4bd3 Author: Daniel Sun <[email protected]> AuthorDate: Sun Apr 5 22:28:54 2026 +0900 GROOVY-11905: Optimize non-capturing lambdas --- .../java/org/codehaus/groovy/ast/ClassNode.java | 8 +- .../classgen/asm/sc/StaticTypesLambdaWriter.java | 455 ++++++++--- .../groovy/groovy/transform/stc/LambdaTest.groovy | 865 +++++++++++++++++++++ .../groovy/classgen/asm/TypeAnnotationsTest.groovy | 2 +- 4 files changed, 1229 insertions(+), 101 deletions(-) diff --git a/src/main/java/org/codehaus/groovy/ast/ClassNode.java b/src/main/java/org/codehaus/groovy/ast/ClassNode.java index b91da80224..3bbd62afd8 100644 --- a/src/main/java/org/codehaus/groovy/ast/ClassNode.java +++ b/src/main/java/org/codehaus/groovy/ast/ClassNode.java @@ -1452,17 +1452,21 @@ faces: if (method == null && asBoolean(getInterfaces())) { // GROOVY-11323 return null; } + private List<ClassNode> outerClasses; public List<ClassNode> getOuterClasses() { + List<ClassNode> ocs = outerClasses; + if (ocs != null) return ocs; + ClassNode outer = getOuterClass(); if (outer == null) { - return Collections.emptyList(); + return outerClasses = Collections.emptyList(); } List<ClassNode> result = new ArrayList<>(4); do { result.add(outer); } while ((outer = outer.getOuterClass()) != null); - return result; + return outerClasses = Collections.unmodifiableList(result); } /** diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java index b804702754..efeada8428 100644 --- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java +++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java @@ -19,18 +19,23 @@ package org.codehaus.groovy.classgen.asm.sc; import org.codehaus.groovy.GroovyBugError; +import org.codehaus.groovy.ast.ClassCodeExpressionTransformer; import org.codehaus.groovy.ast.ClassHelper; import org.codehaus.groovy.ast.ClassNode; import org.codehaus.groovy.ast.CodeVisitorSupport; import org.codehaus.groovy.ast.ConstructorNode; +import org.codehaus.groovy.ast.FieldNode; import org.codehaus.groovy.ast.InnerClassNode; import org.codehaus.groovy.ast.MethodNode; import org.codehaus.groovy.ast.Parameter; +import org.codehaus.groovy.ast.PropertyNode; +import org.codehaus.groovy.ast.Variable; import org.codehaus.groovy.ast.builder.AstStringCompiler; +import org.codehaus.groovy.ast.expr.ClassExpression; import org.codehaus.groovy.ast.expr.ClosureExpression; -import org.codehaus.groovy.ast.expr.ConstantExpression; import org.codehaus.groovy.ast.expr.Expression; import org.codehaus.groovy.ast.expr.LambdaExpression; +import org.codehaus.groovy.ast.expr.PropertyExpression; import org.codehaus.groovy.ast.expr.VariableExpression; import org.codehaus.groovy.ast.stmt.BlockStatement; import org.codehaus.groovy.ast.stmt.Statement; @@ -42,15 +47,15 @@ import org.codehaus.groovy.classgen.asm.LambdaWriter; import org.codehaus.groovy.classgen.asm.OperandStack; import org.codehaus.groovy.classgen.asm.WriterController; import org.codehaus.groovy.classgen.asm.WriterControllerFactory; +import org.codehaus.groovy.control.SourceUnit; import org.codehaus.groovy.transform.sc.StaticCompilationMetadataKeys; import org.objectweb.asm.MethodVisitor; import java.util.HashMap; -import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.Optional; +import static org.apache.groovy.util.BeanUtils.capitalize; import static org.codehaus.groovy.ast.ClassHelper.CLOSURE_TYPE; import static org.codehaus.groovy.ast.ClassHelper.GENERATED_LAMBDA_TYPE; import static org.codehaus.groovy.ast.ClassHelper.OBJECT_TYPE; @@ -64,6 +69,9 @@ import static org.codehaus.groovy.ast.tools.GeneralUtils.declS; import static org.codehaus.groovy.ast.tools.GeneralUtils.localVarX; import static org.codehaus.groovy.ast.tools.GeneralUtils.returnS; import static org.codehaus.groovy.transform.stc.StaticTypesMarker.CLOSURE_ARGUMENTS; +import static org.codehaus.groovy.transform.stc.StaticTypesMarker.DIRECT_METHOD_CALL_TARGET; +import static org.codehaus.groovy.transform.stc.StaticTypesMarker.IMPLICIT_RECEIVER; +import static org.codehaus.groovy.transform.stc.StaticTypesMarker.INFERRED_TYPE; import static org.codehaus.groovy.transform.stc.StaticTypesMarker.PARAMETER_TYPE; import static org.objectweb.asm.Opcodes.ACC_FINAL; import static org.objectweb.asm.Opcodes.ACC_PRIVATE; @@ -72,6 +80,7 @@ import static org.objectweb.asm.Opcodes.ACC_STATIC; import static org.objectweb.asm.Opcodes.ALOAD; import static org.objectweb.asm.Opcodes.CHECKCAST; import static org.objectweb.asm.Opcodes.DUP; +import static org.objectweb.asm.Opcodes.H_INVOKESTATIC; import static org.objectweb.asm.Opcodes.H_INVOKEVIRTUAL; import static org.objectweb.asm.Opcodes.ICONST_0; import static org.objectweb.asm.Opcodes.INVOKESPECIAL; @@ -83,13 +92,6 @@ import static org.objectweb.asm.Opcodes.NEW; */ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFunctionalInterfaceWriter { - private static final String IS_GENERATED_CONSTRUCTOR = "__IS_GENERATED_CONSTRUCTOR"; - private static final String LAMBDA_SHARED_VARIABLES = "__LAMBDA_SHARED_VARIABLES"; - private static final String DO_CALL = "doCall"; - - private final Map<Expression, ClassNode> lambdaClassNodes = new HashMap<>(); - private final StaticTypesClosureWriter staticTypesClosureWriter; - public StaticTypesLambdaWriter(final WriterController controller) { super(controller); this.staticTypesClosureWriter = new StaticTypesClosureWriter(controller); @@ -98,84 +100,125 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun @Override public void writeLambda(final LambdaExpression expression) { // functional interface target is required for native lambda generation - ClassNode functionalType = expression.getNodeMetaData(PARAMETER_TYPE); + ClassNode functionalType = expression.getNodeMetaData(PARAMETER_TYPE); MethodNode abstractMethod = ClassHelper.findSAM(functionalType); - if (abstractMethod == null || !functionalType.isInterface()) { + if (functionalType == null || abstractMethod == null || !functionalType.isInterface()) { // generate bytecode for closure super.writeLambda(expression); return; } + boolean serializable = makeSerializableIfNeeded(expression, functionalType); + GeneratedLambda generatedLambda = getOrAddGeneratedLambda(expression, abstractMethod); + + ensureDeserializeLambdaSupport(expression, generatedLambda, serializable); + if (generatedLambda.isCapturing() && !isPreloadedLambdaReceiver(generatedLambda)) { + loadLambdaReceiver(generatedLambda); + } + + writeLambdaFactoryInvocation(functionalType.redirect(), abstractMethod, generatedLambda, serializable); + } + + private static Parameter[] createDeserializeLambdaMethodParams() { + return new Parameter[]{new Parameter(SERIALIZEDLAMBDA_TYPE, "serializedLambda")}; + } + + private boolean makeSerializableIfNeeded(final LambdaExpression expression, final ClassNode functionalType) { if (!expression.isSerializable() && functionalType.implementsInterface(SERIALIZABLE_TYPE)) { expression.setSerializable(true); } + return expression.isSerializable(); + } - ClassNode lambdaClass = getOrAddLambdaClass(expression, abstractMethod); - MethodNode lambdaMethod = lambdaClass.getMethods(DO_CALL).get(0); - - boolean canDeserialize = controller.getClassNode().hasMethod(createDeserializeLambdaMethodName(lambdaClass), createDeserializeLambdaMethodParams()); - if (!canDeserialize) { - if (expression.isSerializable()) { - addDeserializeLambdaMethodForEachLambdaExpression(expression, lambdaClass); - addDeserializeLambdaMethod(); - } - newGroovyLambdaWrapperAndLoad(lambdaClass, expression, isAccessingInstanceMembersOfEnclosingClass(lambdaMethod)); + private void ensureDeserializeLambdaSupport(final LambdaExpression expression, final GeneratedLambda generatedLambda, final boolean serializable) { + if (!serializable || hasDeserializeLambdaMethod(generatedLambda.lambdaClass)) { + return; } + addDeserializeLambdaMethodForLambdaExpression(expression, generatedLambda); + addDeserializeLambdaMethod(); + } + + private void writeLambdaFactoryInvocation(final ClassNode functionalType, final MethodNode abstractMethod, final GeneratedLambda generatedLambda, final boolean serializable) { MethodVisitor mv = controller.getMethodVisitor(); mv.visitInvokeDynamicInsn( abstractMethod.getName(), - createAbstractMethodDesc(functionalType.redirect(), lambdaClass), - createBootstrapMethod(controller.getClassNode().isInterface(), expression.isSerializable()), - createBootstrapMethodArguments(createMethodDescriptor(abstractMethod), H_INVOKEVIRTUAL, lambdaClass, lambdaMethod, lambdaMethod.getParameters(), expression.isSerializable()) + createLambdaFactoryMethodDescriptor(functionalType, generatedLambda), + createBootstrapMethod(controller.getClassNode().isInterface(), serializable), + createBootstrapMethodArguments(createMethodDescriptor(abstractMethod), + generatedLambda.getMethodHandleKind(), + generatedLambda.lambdaClass, generatedLambda.lambdaMethod, generatedLambda.lambdaMethod.getParameters(), serializable) ); - if (expression.isSerializable()) { + if (serializable) { mv.visitTypeInsn(CHECKCAST, "java/io/Serializable"); } - controller.getOperandStack().replace(functionalType.redirect(), 1); + if (generatedLambda.nonCapturing()) { + controller.getOperandStack().push(functionalType); + } else { + controller.getOperandStack().replace(functionalType, 1); + } } - private static Parameter[] createDeserializeLambdaMethodParams() { - return new Parameter[]{new Parameter(SERIALIZEDLAMBDA_TYPE, "serializedLambda")}; + private boolean hasDeserializeLambdaMethod(final ClassNode lambdaClass) { + return controller.getClassNode().hasMethod(createDeserializeLambdaMethodName(lambdaClass), createDeserializeLambdaMethodParams()); } - private static boolean isAccessingInstanceMembersOfEnclosingClass(final MethodNode lambdaMethod) { - boolean[] result = new boolean[1]; + private static MethodNode getLambdaMethod(final ClassNode lambdaClass) { + List<MethodNode> lambdaMethods = lambdaClass.getMethods(DO_CALL); + if (lambdaMethods.isEmpty()) { + throw new GroovyBugError("Failed to find the synthetic lambda method in " + lambdaClass.getName()); + } + return lambdaMethods.get(0); + } - lambdaMethod.getCode().visit(new CodeVisitorSupport() { - @Override - public void visitConstantExpression(final ConstantExpression expression) { - if ("this".equals(expression.getValue())) { // as in Type.this.name - result[0] = true; - } - } - @Override - public void visitVariableExpression(final VariableExpression expression) { - if ("this".equals(expression.getName()) || "thisObject".equals(expression.getName())) { - result[0] = true; - } else { - var owner = expression.getNodeMetaData(StaticCompilationMetadataKeys.PROPERTY_OWNER); - if (owner != null && lambdaMethod.getDeclaringClass().getOuterClasses().contains(owner)) { - result[0] = true; - } - } + private static ConstructorNode getGeneratedConstructor(final ClassNode lambdaClass) { + for (ConstructorNode constructorNode : lambdaClass.getDeclaredConstructors()) { + if (Boolean.TRUE.equals(constructorNode.getNodeMetaData(GeneratedConstructorMarker.class))) { + return constructorNode; } - }); + } + throw new GroovyBugError("Failed to find the generated constructor in " + lambdaClass.getName()); + } + + /** + * Determines whether the synthetic lambda body needs an enclosing receiver. + * <p> + * Explicit instance access such as {@code this}, {@code super}, {@code thisObject}, + * and qualified references like {@code Outer.this} or {@code Outer.super} must stay on + * the capturing path. Plain string literals such as {@code "this"} must not disable the + * non-capturing optimization. + */ + private static boolean isAccessingInstanceMembersOfEnclosingClass(final MethodNode lambdaMethod) { + Boolean accessingInstanceMembers = lambdaMethod.getNodeMetaData(AccessesInstanceMembersMarker.class); + if (accessingInstanceMembers != null) { + return accessingInstanceMembers; + } + + InstanceMemberAccessFinder finder = new InstanceMemberAccessFinder(lambdaMethod.getDeclaringClass().getOuterClasses()); + lambdaMethod.getCode().visit(finder); - return result[0]; + accessingInstanceMembers = finder.isAccessingInstanceMembers(); + lambdaMethod.putNodeMetaData(AccessesInstanceMembersMarker.class, accessingInstanceMembers); + return accessingInstanceMembers; } - private void newGroovyLambdaWrapperAndLoad(final ClassNode lambdaClass, final LambdaExpression expression, final boolean accessingInstanceMembers) { + private boolean isPreloadedLambdaReceiver(final GeneratedLambda generatedLambda) { + MethodNode enclosingMethod = controller.getMethodNode(); + return enclosingMethod != null + && enclosingMethod.getNodeMetaData(PreloadedLambdaReceiverMarker.class) == generatedLambda.lambdaClass; + } + + private void loadLambdaReceiver(final GeneratedLambda generatedLambda) { CompileStack compileStack = controller.getCompileStack(); OperandStack operandStack = controller.getOperandStack(); MethodVisitor mv = controller.getMethodVisitor(); - String lambdaClassInternalName = BytecodeHelper.getClassInternalName(lambdaClass); + String lambdaClassInternalName = BytecodeHelper.getClassInternalName(generatedLambda.lambdaClass); mv.visitTypeInsn(NEW, lambdaClassInternalName); mv.visitInsn(DUP); - if (controller.isStaticMethod() || compileStack.isInSpecialConstructorCall() || !accessingInstanceMembers) { + if (controller.isStaticMethod() || compileStack.isInSpecialConstructorCall() || !generatedLambda.accessingInstanceMembers) { classX(controller.getThisType()).visit(controller.getAcg()); } else { loadThis(); @@ -183,23 +226,15 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun operandStack.dup(); - loadSharedVariables(expression); - - Optional<ConstructorNode> generatedConstructor = lambdaClass.getDeclaredConstructors().stream() - .filter(ctor -> Boolean.TRUE.equals(ctor.getNodeMetaData(IS_GENERATED_CONSTRUCTOR))).findFirst(); - if (generatedConstructor.isEmpty()) { - throw new GroovyBugError("Failed to find the generated constructor"); - } + loadSharedVariables(generatedLambda.sharedVariables); - Parameter[] lambdaClassConstructorParameters = generatedConstructor.get().getParameters(); - mv.visitMethodInsn(INVOKESPECIAL, lambdaClassInternalName, "<init>", BytecodeHelper.getMethodDescriptor(VOID_TYPE, lambdaClassConstructorParameters), lambdaClass.isInterface()); + Parameter[] lambdaClassConstructorParameters = generatedLambda.constructor.getParameters(); + mv.visitMethodInsn(INVOKESPECIAL, lambdaClassInternalName, "<init>", BytecodeHelper.getMethodDescriptor(VOID_TYPE, lambdaClassConstructorParameters), generatedLambda.lambdaClass.isInterface()); operandStack.replace(CLOSURE_TYPE, lambdaClassConstructorParameters.length); } - private void loadSharedVariables(final LambdaExpression expression) { - Parameter[] lambdaSharedVariableParameters = expression.getNodeMetaData(LAMBDA_SHARED_VARIABLES); - + private void loadSharedVariables(final Parameter[] lambdaSharedVariableParameters) { for (Parameter parameter : lambdaSharedVariableParameters) { loadReference(parameter.getName(), controller); if (parameter.getNodeMetaData(UseExistingReference.class) == null) { @@ -208,20 +243,35 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun } } - private String createAbstractMethodDesc(final ClassNode functionalInterface, final ClassNode lambdaClass) { - List<Parameter> lambdaSharedVariables = new LinkedList<>(); - prependParameter(lambdaSharedVariables, "__lambda_this", lambdaClass); - return BytecodeHelper.getMethodDescriptor(functionalInterface, lambdaSharedVariables.toArray(Parameter.EMPTY_ARRAY)); + private String createLambdaFactoryMethodDescriptor(final ClassNode functionalInterface, final GeneratedLambda generatedLambda) { + if (generatedLambda.nonCapturing()) { + return BytecodeHelper.getMethodDescriptor(functionalInterface, Parameter.EMPTY_ARRAY); + } + return BytecodeHelper.getMethodDescriptor(functionalInterface, new Parameter[]{createLambdaReceiverParameter(generatedLambda.lambdaClass)}); } - private ClassNode getOrAddLambdaClass(final LambdaExpression expression, final MethodNode abstractMethod) { - return lambdaClassNodes.computeIfAbsent(expression, expr -> { - ClassNode lambdaClass = createLambdaClass((LambdaExpression) expr, ACC_FINAL | ACC_PUBLIC | ACC_STATIC, abstractMethod); + private static Parameter createLambdaReceiverParameter(final ClassNode lambdaClass) { + Parameter parameter = new Parameter(lambdaClass, "__lambda_this"); + parameter.setClosureSharedVariable(false); + return parameter; + } + + private GeneratedLambda getOrAddGeneratedLambda(final LambdaExpression expression, final MethodNode abstractMethod) { + return generatedLambdas.computeIfAbsent(expression, expr -> { + ClassNode lambdaClass = createLambdaClass(expr, ACC_FINAL | ACC_PUBLIC | ACC_STATIC, abstractMethod); controller.getAcg().addInnerClass(lambdaClass); lambdaClass.addInterface(GENERATED_LAMBDA_TYPE); lambdaClass.putNodeMetaData(StaticCompilationMetadataKeys.STATIC_COMPILE_NODE, Boolean.TRUE); lambdaClass.putNodeMetaData(WriterControllerFactory.class, (WriterControllerFactory) x -> controller); - return lambdaClass; + MethodNode lambdaMethod = getLambdaMethod(lambdaClass); + return new GeneratedLambda( + lambdaClass, + lambdaMethod, + getGeneratedConstructor(lambdaClass), + getStoredLambdaSharedVariables(expr), + !requiresLambdaInstance(lambdaMethod), + isAccessingInstanceMembersOfEnclosingClass(lambdaMethod) + ); }); } @@ -252,11 +302,11 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun MethodNode syntheticLambdaMethodNode = addSyntheticLambdaMethodNode(expression, lambdaClass, abstractMethod); - Parameter[] localVariableParameters = expression.getNodeMetaData(LAMBDA_SHARED_VARIABLES); + Parameter[] localVariableParameters = getStoredLambdaSharedVariables(expression); addFieldsForLocalVariables(lambdaClass, localVariableParameters); ConstructorNode constructorNode = addConstructor(expression, localVariableParameters, lambdaClass, createBlockStatementForConstructor(expression, outermostClass, enclosingClass)); - constructorNode.putNodeMetaData(IS_GENERATED_CONSTRUCTOR, Boolean.TRUE); + constructorNode.putNodeMetaData(GeneratedConstructorMarker.class, Boolean.TRUE); syntheticLambdaMethodNode.getCode().visit(new CorrectAccessedVariableVisitor(lambdaClass)); @@ -274,7 +324,7 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun Parameter[] localVariableParameters = getLambdaSharedVariables(expression); removeInitialValues(localVariableParameters); - expression.putNodeMetaData(LAMBDA_SHARED_VARIABLES, localVariableParameters); + expression.putNodeMetaData(StoredLambdaSharedVariablesMarker.class, localVariableParameters); MethodNode doCallMethod = lambdaClass.addMethod( DO_CALL, @@ -285,9 +335,42 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun expression.getCode() ); doCallMethod.setSourcePosition(expression); + if (isNonCapturingLambda(doCallMethod, localVariableParameters)) { + qualifyOuterStaticMemberReferences(doCallMethod); + doCallMethod.setModifiers(doCallMethod.getModifiers() | ACC_STATIC); + } return doCallMethod; } + private static boolean isNonCapturingLambda(final MethodNode lambdaMethod, final Parameter[] lambdaSharedVariables) { + return (lambdaSharedVariables == null || lambdaSharedVariables.length == 0) + && !isAccessingInstanceMembersOfEnclosingClass(lambdaMethod); + } + + private void qualifyOuterStaticMemberReferences(final MethodNode lambdaMethod) { + lambdaMethod.getCode().visit(new ClassCodeExpressionTransformer() { + @Override + protected SourceUnit getSourceUnit() { + return controller.getSourceUnit(); + } + + @Override + public Expression transform(final Expression expression) { + if (expression instanceof VariableExpression variableExpression) { + ClassNode owner = getStaticOuterMemberOwner(variableExpression, lambdaMethod.getDeclaringClass().getOuterClasses()); + if (owner != null) { + PropertyExpression qualifiedReference = new PropertyExpression(classX(owner), variableExpression.getName()); + qualifiedReference.setImplicitThis(false); + qualifiedReference.copyNodeMetaData(variableExpression); + setSourcePosition(qualifiedReference, variableExpression); + return qualifiedReference; + } + } + return super.transform(expression); + } + }); + } + private Parameter[] createParametersWithExactType(final LambdaExpression expression, final MethodNode abstractMethod) { Parameter[] targetParameters = abstractMethod.getParameters(); Parameter[] lambdaParameters = getParametersSafe(expression); @@ -325,38 +408,214 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun code); } - private void addDeserializeLambdaMethodForEachLambdaExpression(final LambdaExpression expression, final ClassNode lambdaClass) { + private static boolean requiresLambdaInstance(final MethodNode lambdaMethod) { + return 0 == (lambdaMethod.getModifiers() & ACC_STATIC); + } + + private static ClassNode getStaticOuterMemberOwner(final VariableExpression expression, final List<ClassNode> outerClasses) { + ClassNode owner = expression.getNodeMetaData(StaticCompilationMetadataKeys.PROPERTY_OWNER); + if (owner == null || !outerClasses.contains(owner)) { + return null; + } + return isStaticOuterMemberReference(expression, owner) ? owner : null; + } + + private static boolean isStaticOuterMemberReference(final VariableExpression expression, final ClassNode owner) { + Variable accessedVariable = expression.getAccessedVariable(); + if (accessedVariable instanceof FieldNode) { + return accessedVariable.isStatic(); + } + if (accessedVariable instanceof PropertyNode) { + return accessedVariable.isStatic(); + } + if (accessedVariable instanceof Parameter) { + return false; + } + + MethodNode directMethodCallTarget = expression.getNodeMetaData(DIRECT_METHOD_CALL_TARGET); + if (directMethodCallTarget != null) { + return directMethodCallTarget.isStatic() && directMethodCallTarget.getParameters().length == 0; + } + + return isStaticMemberNamed(expression.getName(), owner); + } + + private static boolean isStaticMemberNamed(final String propertyName, final ClassNode owner) { + FieldNode field = owner.getField(propertyName); + if (field != null && field.isStatic()) { + return true; + } + + PropertyNode property = owner.getProperty(propertyName); + if (property != null && property.isStatic()) { + return true; + } + + MethodNode getter = owner.getGetterMethod("is" + capitalize(propertyName)); + if (getter == null) { + getter = owner.getGetterMethod("get" + capitalize(propertyName)); + } + return getter != null && getter.isStatic(); + } + + private void addDeserializeLambdaMethodForLambdaExpression(final LambdaExpression expression, final GeneratedLambda generatedLambda) { ClassNode enclosingClass = controller.getClassNode(); - Statement code = block( - new BytecodeSequence(new BytecodeInstruction() { - @Override - public void visit(final MethodVisitor mv) { - mv.visitVarInsn(ALOAD, 0); - mv.visitInsn(ICONST_0); - mv.visitMethodInsn( - INVOKEVIRTUAL, - "java/lang/invoke/SerializedLambda", - "getCapturedArg", - "(I)Ljava/lang/Object;", - false); - mv.visitTypeInsn(CHECKCAST, BytecodeHelper.getClassInternalName(lambdaClass)); - OperandStack operandStack = controller.getOperandStack(); - operandStack.push(lambdaClass); - } - }), - returnS(expression) - ); + Statement code; + if (generatedLambda.nonCapturing()) { + code = block(returnS(expression)); + } else { + code = block( + new BytecodeSequence(new BytecodeInstruction() { + @Override + public void visit(final MethodVisitor mv) { + mv.visitVarInsn(ALOAD, 0); + mv.visitInsn(ICONST_0); + mv.visitMethodInsn( + INVOKEVIRTUAL, + "java/lang/invoke/SerializedLambda", + "getCapturedArg", + "(I)Ljava/lang/Object;", + false); + mv.visitTypeInsn(CHECKCAST, BytecodeHelper.getClassInternalName(generatedLambda.lambdaClass)); + OperandStack operandStack = controller.getOperandStack(); + operandStack.push(generatedLambda.lambdaClass); + } + }), + returnS(expression) + ); + } - enclosingClass.addSyntheticMethod( - createDeserializeLambdaMethodName(lambdaClass), + MethodNode deserializeLambdaMethod = enclosingClass.addSyntheticMethod( + createDeserializeLambdaMethodName(generatedLambda.lambdaClass), ACC_PUBLIC | ACC_STATIC, OBJECT_TYPE, createDeserializeLambdaMethodParams(), ClassNode.EMPTY_ARRAY, code); + if (generatedLambda.isCapturing()) { + // The deserialize helper preloads the captured receiver before it reuses the original lambda expression. + deserializeLambdaMethod.putNodeMetaData(PreloadedLambdaReceiverMarker.class, generatedLambda.lambdaClass); + } } private static String createDeserializeLambdaMethodName(final ClassNode lambdaClass) { return "$deserializeLambda_" + lambdaClass.getName().replace('.', '$') + "$"; } + + private static Parameter[] getStoredLambdaSharedVariables(final LambdaExpression expression) { + Parameter[] sharedVariables = expression.getNodeMetaData(StoredLambdaSharedVariablesMarker.class); + if (sharedVariables == null) { + throw new GroovyBugError("Failed to find shared variables for lambda expression"); + } + return sharedVariables; + } + + private abstract static class GeneratedConstructorMarker { + } + + private abstract static class StoredLambdaSharedVariablesMarker { + } + + private abstract static class AccessesInstanceMembersMarker { + } + + private abstract static class PreloadedLambdaReceiverMarker { + } + + private record GeneratedLambda(ClassNode lambdaClass, MethodNode lambdaMethod, ConstructorNode constructor, + Parameter[] sharedVariables, boolean nonCapturing, + boolean accessingInstanceMembers) { + + private boolean isCapturing() { + return !nonCapturing; + } + + private int getMethodHandleKind() { + return nonCapturing ? H_INVOKESTATIC : H_INVOKEVIRTUAL; + } + } + + private static final class InstanceMemberAccessFinder extends CodeVisitorSupport { + + private final List<ClassNode> outerClasses; + private boolean accessingInstanceMembers; + + private InstanceMemberAccessFinder(final List<ClassNode> outerClasses) { + this.outerClasses = outerClasses; + } + + @Override + public void visitVariableExpression(final VariableExpression expression) { + if (accessingInstanceMembers) { + return; + } + if (expression.isThisExpression() || expression.isSuperExpression() || "thisObject".equals(expression.getName())) { + accessingInstanceMembers = true; + return; + } + + ClassNode owner = expression.getNodeMetaData(StaticCompilationMetadataKeys.PROPERTY_OWNER); + if (owner != null && outerClasses.contains(owner) && getStaticOuterMemberOwner(expression, outerClasses) == null) { + accessingInstanceMembers = true; + return; + } + + super.visitVariableExpression(expression); + } + + @Override + public void visitPropertyExpression(final PropertyExpression expression) { + if (accessingInstanceMembers) { + return; + } + if (isImplicitOuterStaticPropertyReference(expression)) { + return; + } + if (isQualifiedEnclosingInstanceReference(expression)) { + accessingInstanceMembers = true; + return; + } + + super.visitPropertyExpression(expression); + } + + private boolean isImplicitOuterStaticPropertyReference(final PropertyExpression expression) { + if (!(expression.getObjectExpression() instanceof VariableExpression receiver)) { + return false; + } + + if (receiver.getNodeMetaData(IMPLICIT_RECEIVER) == null) { + return false; + } + + ClassNode owner = receiver.getNodeMetaData(INFERRED_TYPE); + if (owner == null || !outerClasses.contains(owner)) { + return false; + } + + MethodNode directMethodCallTarget = expression.getNodeMetaData(DIRECT_METHOD_CALL_TARGET); + return (directMethodCallTarget != null + && directMethodCallTarget.isStatic() + && directMethodCallTarget.getParameters().length == 0 + && outerClasses.contains(directMethodCallTarget.getDeclaringClass())) + || isStaticMemberNamed(expression.getPropertyAsString(), owner); + } + + private static boolean isQualifiedEnclosingInstanceReference(final PropertyExpression expression) { + if (!(expression.getObjectExpression() instanceof ClassExpression)) { + return false; + } + + String property = expression.getPropertyAsString(); + return "this".equals(property) || "super".equals(property); + } + + private boolean isAccessingInstanceMembers() { + return accessingInstanceMembers; + } + } + + private static final String DO_CALL = "doCall"; + private final Map<LambdaExpression, GeneratedLambda> generatedLambdas = new HashMap<>(); + private final StaticTypesClosureWriter staticTypesClosureWriter; } diff --git a/src/test/groovy/groovy/transform/stc/LambdaTest.groovy b/src/test/groovy/groovy/transform/stc/LambdaTest.groovy index f69eb1a332..a93327587a 100644 --- a/src/test/groovy/groovy/transform/stc/LambdaTest.groovy +++ b/src/test/groovy/groovy/transform/stc/LambdaTest.groovy @@ -18,6 +18,8 @@ */ package groovy.transform.stc +import org.codehaus.groovy.classgen.asm.AbstractBytecodeTestCase +import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import static groovy.test.GroovyAssert.assertScript @@ -1888,4 +1890,867 @@ final class LambdaTest { assert this.class.classLoader.loadClass('Foo$_bar_lambda1').modifiers == 25 // public(1) + static(8) + final(16) ''' } + + // GROOVY-11905 + @Nested + class NonCapturingLambdaOptimizationTest extends AbstractBytecodeTestCase { + @Test + void testNonCapturingLambdaWithFunctionInStaticMethod() { + assertScript shell, ''' + class C { + static void test() { + assert [2, 3, 4] == [1, 2, 3].stream().map(e -> e + 1).toList() + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithFunctionInInstanceMethodWithoutThisAccess() { + assertScript shell, ''' + class C { + void test() { + assert [2, 3, 4] == [1, 2, 3].stream().map(e -> e + 1).toList() + } + } + new C().test() + ''' + } + + @Test + void testNonCapturingLambdaWithPredicate() { + assertScript shell, ''' + class C { + static void test() { + assert [2, 4] == [1, 2, 3, 4].stream().filter(e -> e % 2 == 0).toList() + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithSupplier() { + assertScript shell, ''' + class C { + static void test() { + Supplier<String> s = () -> 'constant' + assert s.get() == 'constant' + assert 'hello' == Optional.<String>empty().orElseGet(() -> 'hello') + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithBiFunction() { + assertScript shell, ''' + class C { + static void test() { + BiFunction<Integer, Integer, Integer> f = (a, b) -> a + b + assert f.apply(3, 4) == 7 + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithComparator() { + assertScript shell, ''' + class C { + static void test() { + assert [3, 2, 1] == [1, 2, 3].stream().sorted((a, b) -> b.compareTo(a)).toList() + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithPrimitiveParameterType() { + assertScript shell, ''' + class C { + static void test() { + IntUnaryOperator op = (int i) -> i * 2 + assert op.applyAsInt(5) == 10 + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithCustomFunctionalInterface() { + assertScript shell, ''' + interface Transformer<I, O> { + O transform(I input) + } + class C { + static void test() { + Transformer<String, Integer> t = (String s) -> s.length() + assert t.transform('hello') == 5 + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaCallingStaticMethodOnly() { + assertScript shell, ''' + class C { + static String prefix() { 'Hi ' } + static void test() { + assert ['Hi 1', 'Hi 2'] == [1, 2].stream().map(e -> C.prefix() + e).toList() + } + } + C.test() + ''' + } + + @Test + void testMultipleNonCapturingLambdasInSameMethod() { + assertScript shell, ''' + class C { + static void test() { + Function<Integer, Integer> f = (Integer x) -> x + 1 + Function<Integer, String> g = (Integer x) -> 'v' + x + Predicate<Integer> p = (Integer x) -> x > 2 + assert f.apply(1) == 2 + assert g.apply(1) == 'v1' + assert p.test(3) && !p.test(1) + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaInStaticInitializerBlock() { + assertScript shell, ''' + class C { + static List<Integer> result + static { result = [1, 2, 3].stream().map(e -> e * 2).toList() } + } + assert C.result == [2, 4, 6] + ''' + } + + @Test + void testNonCapturingLambdaInFieldInitializer() { + assertScript shell, ''' + class C { + IntUnaryOperator op = (int i) -> i + 1 + void test() { assert op.applyAsInt(5) == 6 } + } + new C().test() + ''' + } + + @Test + void testNonCapturingLambdaInInterfaceDefaultMethod() { + assertScript shell, ''' + interface Processor { + default List<Integer> process(List<Integer> input) { + input.stream().map(e -> e + 1).toList() + } + } + class C implements Processor {} + assert new C().process([1, 2, 3]) == [2, 3, 4] + ''' + } + + @Test + void testNonCapturingLambdaSingletonIdentity() { + assertScript shell, ''' + class C { + static void test() { + def identities = new HashSet() + for (int i = 0; i < 5; i++) { + Function<Integer, Integer> f = (Integer x) -> x + 1 + identities.add(System.identityHashCode(f)) + } + assert identities.size() == 1 : 'non-capturing lambda should be a singleton' + } + } + C.test() + ''' + } + + @Test + void testCapturingLambdaCreatesDistinctInstances() { + assertScript shell, ''' + class C { + static void test() { + def identities = new HashSet() + for (int i = 0; i < 3; i++) { + int captured = i + Function<Integer, Integer> f = (Integer x) -> x + captured + identities.add(System.identityHashCode(f)) + assert f.apply(10) == 10 + i + } + assert identities.size() == 3 : 'capturing lambda should create different instances' + } + } + C.test() + ''' + } + + @Test + void testCapturingLocalVariable() { + assertScript shell, ''' + class C { + static void test() { + String x = '#' + assert ['#1', '#2'] == [1, 2].stream().map(e -> x + e).toList() + } + } + C.test() + ''' + } + + @Test + void testAccessingThis() { + assertScript shell, ''' + class C { + String prefix = 'Hi ' + void test() { + assert ['Hi 1', 'Hi 2'] == [1, 2].stream().map(e -> this.prefix + e).toList() + } + } + new C().test() + ''' + } + + @Test + void testCallingInstanceMethod() { + assertScript shell, ''' + class C { + String greet(int i) { "Hello $i" } + void test() { + assert ['Hello 1', 'Hello 2'] == [1, 2].stream().map(e -> greet(e)).toList() + } + } + new C().test() + ''' + } + + @Test + void testCallingSuperMethod() { + assertScript shell, ''' + class Base { + String greet(int i) { "Hello $i" } + } + class C extends Base { + void test() { + assert ['Hello 1', 'Hello 2'] == [1, 2].stream().map(e -> super.greet(e)).toList() + } + } + new C().test() + ''' + } + + @Test + void testNonCapturingLambdaWithThisStringLiteralRemainsSingleton() { + assertScript shell, ''' + class C { + static void test() { + def identities = new HashSet() + for (int i = 0; i < 5; i++) { + Supplier<String> supplier = () -> 'this' + identities.add(System.identityHashCode(supplier)) + assert supplier.get() == 'this' + } + assert identities.size() == 1 : 'non-capturing lambda with string literal this should still be a singleton' + } + } + C.test() + ''' + } + + @Test + void testNonCapturingSerializableLambdaCanBeSerialized() { + assertScript shell, ''' + import java.io.* + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + byte[] test() { + try (def out = new ByteArrayOutputStream()) { + out.withObjectOutputStream { + SerFunc<Integer, String> f = ((Integer i) -> 'a' + i) + it.writeObject(f) + } + out.toByteArray() + } + } + assert test().length > 0 + ''' + } + + @Test + void testNonCapturingSerializableLambdaRoundTrips() { + assertScript shell, ''' + package tests.lambda + class C { + static byte[] test() { + def out = new ByteArrayOutputStream() + out.withObjectOutputStream { it -> + SerFunc<Integer, String> f = (Integer i) -> 'a' + i + it.writeObject(f) + } + out.toByteArray() + } + static main(args) { + new ByteArrayInputStream(C.test()).withObjectInputStream(C.classLoader) { + SerFunc<Integer, String> f = (SerFunc<Integer, String>) it.readObject() + assert f.apply(1) == 'a1' + } + } + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + } + ''' + } + + @Test + void testNonCapturingSerializableLambdaSingletonIdentity() { + assertScript shell, ''' + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + class C { + static void test() { + def identities = new HashSet() + for (int i = 0; i < 5; i++) { + SerFunc<Integer, Integer> f = (Integer x) -> x + 1 + identities.add(System.identityHashCode(f)) + } + assert identities.size() == 1 : 'non-capturing serializable lambda should be a singleton' + } + } + C.test() + ''' + } + + @Test + void testCapturingSerializableLambdaStillRoundTrips() { + assertScript shell, ''' + package tests.lambda + class C { + byte[] test() { + def out = new ByteArrayOutputStream() + out.withObjectOutputStream { + String s = 'a' + SerFunc<Integer, String> f = (Integer i) -> s + i + it.writeObject(f) + } + out.toByteArray() + } + static main(args) { + new ByteArrayInputStream(C.newInstance().test()).withObjectInputStream(C.classLoader) { + SerFunc<Integer, String> f = (SerFunc<Integer, String>) it.readObject() + assert f.apply(1) == 'a1' + } + } + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + } + ''' + } + + @Test + void testCapturingLambdaWithRunnable() { + assertScript shell, ''' + import java.util.concurrent.atomic.AtomicBoolean + class C { + static void test() { + AtomicBoolean ran = new AtomicBoolean(false) + Runnable r = () -> ran.set(true) + r.run() + assert ran.get() + } + } + C.test() + ''' + } + + @Test + void testCapturingLambdaWithConsumer() { + assertScript shell, ''' + class C { + static void test() { + def result = [] + Consumer<Integer> c = (Integer x) -> result.add(x * 2) + c.accept(3) + c.accept(5) + assert result == [6, 10] + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaAccessingStaticField() { + assertScript shell, ''' + class C { + static final int OFFSET = 100 + static void test() { + assert [101, 102, 103] == [1, 2, 3].stream().map(e -> e + OFFSET).toList() + } + } + C.test() + ''' + } + + @Test + void testQualifiedOuterThisRemainsCapturing() { + assertScript shell, ''' + class Outer { + String name = 'outer' + class Inner { + void test() { + Function<Integer, String> f = (Integer x) -> Outer.this.name + x + assert f.apply(1) == 'outer1' + } + } + void test() { new Inner().test() } + } + new Outer().test() + ''' + } + + @Test + void testNestedNonCapturingLambdas() { + assertScript shell, ''' + class C { + static void test() { + Function<List<Integer>, List<Integer>> f = (List<Integer> list) -> + list.stream().map(e -> e * 2).toList() + assert f.apply([1, 2, 3]) == [2, 4, 6] + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaInStaticMethodUsesStaticDoCall() { + def bytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', ''' + @CompileStatic + class C { + static IntUnaryOperator create() { + (int i) -> i * 2 + } + } + ''') + assert bytecode.hasStrictSequence([ + 'public static doCall(I)I', + 'L0' + ]) + } + + @Test + void testNonCapturingLambdaInInstanceMethodWithoutThisAccessUsesCaptureFreeInvokeDynamic() { + def bytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', ''' + @CompileStatic + class C { + IntUnaryOperator create() { + (int i) -> i + 1 + } + } + ''') + assert bytecode.hasSequence([ + 'INVOKEDYNAMIC applyAsInt()Ljava/util/function/IntUnaryOperator;', + 'java/lang/invoke/LambdaMetafactory.metafactory', + 'C$_create_lambda1.doCall(I)I' + ]) + assert !bytecode.hasSequence(['NEW C$_create_lambda1']) + } + + @Test + void testNonCapturingLambdaWithThisStringLiteralUsesCaptureFreeInvokeDynamic() { + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', ''' + @CompileStatic + class C { + static Supplier<String> create() { + () -> 'this' + } + } + ''') + assert lambdaBytecode.hasStrictSequence([ + 'public static doCall()Ljava/lang/Object;', + 'L0' + ]) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', ''' + @CompileStatic + class C { + static Supplier<String> create() { + () -> 'this' + } + } + ''') + assert outerBytecode.hasSequence([ + 'INVOKEDYNAMIC get()Ljava/util/function/Supplier;', + 'java/lang/invoke/LambdaMetafactory.metafactory', + 'C$_create_lambda1.doCall()Ljava/lang/Object;' + ]) + assert !outerBytecode.hasSequence(['NEW C$_create_lambda1']) + } + + @Test + void testCapturingLambdaRetainsInstanceDoCallAndCapturedReceiver() { + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', ''' + @CompileStatic + class C { + static IntUnaryOperator create() { + int captured = 1 + IntUnaryOperator op = (int i) -> i + captured + op + } + } + ''') + assert lambdaBytecode.hasSequence(['public doCall(I)I']) + assert !lambdaBytecode.hasSequence(['public static doCall(I)I']) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', ''' + @CompileStatic + class C { + static IntUnaryOperator create() { + int captured = 1 + IntUnaryOperator op = (int i) -> i + captured + op + } + } + ''') + assert outerBytecode.hasSequence([ + 'NEW C$_create_lambda1', + 'INVOKEDYNAMIC applyAsInt(LC$_create_lambda1;)Ljava/util/function/IntUnaryOperator;', + 'C$_create_lambda1.doCall(I)I' + ]) + } + + @Test + void testSuperMethodCallRetainsInstanceDoCallAndCapturedReceiver() { + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', ''' + @CompileStatic + class Base { + String greet(int i) { "Hello $i" } + } + @CompileStatic + class C extends Base { + Function<Integer, String> create() { + (Integer i) -> super.greet(i) + } + } + ''') + assert lambdaBytecode.hasSequence(['public doCall(Ljava/lang/Integer;)Ljava/lang/Object;']) + assert !lambdaBytecode.hasSequence(['public static doCall(Ljava/lang/Integer;)Ljava/lang/Object;']) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', ''' + @CompileStatic + class Base { + String greet(int i) { "Hello $i" } + } + @CompileStatic + class C extends Base { + Function<Integer, String> create() { + (Integer i) -> super.greet(i) + } + } + ''') + assert outerBytecode.hasSequence([ + 'NEW C$_create_lambda1', + 'INVOKEDYNAMIC apply(LC$_create_lambda1;)Ljava/util/function/Function;', + 'C$_create_lambda1.doCall(Ljava/lang/Integer;)Ljava/lang/Object;' + ]) + } + + @Test + void testNonCapturingSerializableLambdaDeserializeHelperSkipsCapturedArgLookup() { + def bytecode = compileStaticBytecode(classNamePattern: 'C', method: '$deserializeLambda_C$_create_lambda1$', ''' + @CompileStatic + class C { + static SerFunc<Integer, String> create() { + (Integer i) -> 'a' + i + } + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + } + ''') + assert !bytecode.hasSequence([SERIALIZED_LAMBDA_GET_CAPTURED_ARG]) + } + + @Test + void testCapturingSerializableLambdaDeserializeHelperReadsCapturedArg() { + def bytecode = compileStaticBytecode(classNamePattern: 'C', method: '$deserializeLambda_C$_create_lambda1$', ''' + @CompileStatic + class C { + static SerFunc<Integer, String> create() { + String prefix = 'a' + SerFunc<Integer, String> f = (Integer i) -> prefix + i + f + } + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + } + ''') + assert bytecode.hasSequence([SERIALIZED_LAMBDA_GET_CAPTURED_ARG]) + } + + @Test + void testNonCapturingLambdaAccessingStaticFieldUsesCaptureFreeInvokeDynamic() { + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', ''' + @CompileStatic + class C { + static final int OFFSET = 100 + static IntUnaryOperator create() { + (int i) -> i + OFFSET + } + } + ''') + assert lambdaBytecode.hasStrictSequence([ + 'public static doCall(I)I', + 'L0' + ]) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', ''' + @CompileStatic + class C { + static final int OFFSET = 100 + static IntUnaryOperator create() { + (int i) -> i + OFFSET + } + } + ''') + assert outerBytecode.hasSequence([ + 'INVOKEDYNAMIC applyAsInt()Ljava/util/function/IntUnaryOperator;', + 'java/lang/invoke/LambdaMetafactory.metafactory', + 'C$_create_lambda1.doCall(I)I' + ]) + assert !outerBytecode.hasSequence(['NEW C$_create_lambda1']) + } + + @Test + void testNonCapturingLambdaCallingQualifiedStaticMethodOnlyUsesCaptureFreeInvokeDynamic() { + def script = ''' + @CompileStatic + class C { + static String prefix() { 'Hi ' } + static Function<Integer, String> create() { + (Integer i) -> C.prefix() + i + } + } + ''' + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', script) + assert lambdaBytecode.hasSequence(['public static doCall(Ljava/lang/Integer;)Ljava/lang/Object;']) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', script) + assert outerBytecode.hasSequence([ + 'INVOKEDYNAMIC apply()Ljava/util/function/Function;', + 'java/lang/invoke/LambdaMetafactory.metafactory', + 'C$_create_lambda1.doCall(Ljava/lang/Integer;)Ljava/lang/Object;' + ]) + assert !outerBytecode.hasSequence(['NEW C$_create_lambda1']) + } + + @Test + void testNonCapturingComparatorLambdaUsesCaptureFreeInvokeDynamic() { + def script = ''' + @CompileStatic + class C { + static java.util.Comparator<Integer> create() { + (Integer left, Integer right) -> right.compareTo(left) + } + } + ''' + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', script) + assert lambdaBytecode.hasSequence(['public static doCall(Ljava/lang/Integer;Ljava/lang/Integer;)I']) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', script) + assert outerBytecode.hasSequence([ + 'INVOKEDYNAMIC compare()Ljava/util/Comparator;', + 'java/lang/invoke/LambdaMetafactory.metafactory', + 'C$_create_lambda1.doCall(Ljava/lang/Integer;Ljava/lang/Integer;)I' + ]) + assert !outerBytecode.hasSequence(['NEW C$_create_lambda1']) + } + + @Test + void testNonCapturingLambdaWithCustomFunctionalInterfaceUsesCaptureFreeInvokeDynamic() { + def script = ''' + interface Transformer<I, O> { + O transform(I input) + } + @CompileStatic + class C { + static Transformer<String, Integer> create() { + (String s) -> s.length() + } + } + ''' + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', script) + assert lambdaBytecode.hasSequence(['public static doCall(Ljava/lang/String;)Ljava/lang/Object;']) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', script) + assert outerBytecode.hasSequence([ + 'INVOKEDYNAMIC transform()LTransformer;', + 'java/lang/invoke/LambdaMetafactory.metafactory', + 'C$_create_lambda1.doCall(Ljava/lang/String;)Ljava/lang/Object;' + ]) + assert !outerBytecode.hasSequence(['NEW C$_create_lambda1']) + } + + @Test + void testNonCapturingSerializableLambdaUsesCaptureFreeAltMetafactory() { + def script = ''' + @CompileStatic + class C { + static SerFunc<Integer, String> create() { + (Integer i) -> 'a' + i + } + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + } + ''' + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', script) + assert lambdaBytecode.hasSequence(['public static doCall(Ljava/lang/Integer;)Ljava/lang/Object;']) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', script) + assert outerBytecode.hasSequence([ + 'INVOKEDYNAMIC apply()LC$SerFunc;', + 'java/lang/invoke/LambdaMetafactory.altMetafactory', + 'C$_create_lambda1.doCall(Ljava/lang/Integer;)Ljava/lang/Object;' + ]) + assert outerBytecode.hasSequence(['CHECKCAST java/io/Serializable']) + assert !outerBytecode.hasSequence(['NEW C$_create_lambda1']) + } + + @Test + void testNonCapturingLambdaInStaticInitializerUsesCaptureFreeInvokeDynamic() { + def script = ''' + @CompileStatic + class C { + static IntUnaryOperator op + static { + op = (int i) -> i + 1 + } + } + ''' + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$__clinit__lambda1', method: 'doCall', script) + assert lambdaBytecode.hasSequence(['public static doCall(I)I']) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: '<clinit>', script) + assert outerBytecode.hasSequence([ + 'INVOKEDYNAMIC applyAsInt()Ljava/util/function/IntUnaryOperator;', + 'java/lang/invoke/LambdaMetafactory.metafactory', + 'C$__clinit__lambda1.doCall(I)I' + ]) + assert !outerBytecode.hasSequence(['NEW C$__clinit__lambda1']) + } + + @Test + void testNonCapturingLambdaInFieldInitializerUsesCaptureFreeInvokeDynamic() { + def script = ''' + @CompileStatic + class C { + IntUnaryOperator op = (int i) -> i + 1 + } + ''' + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$_lambda1', method: 'doCall', script) + assert lambdaBytecode.hasSequence(['public static doCall(I)I']) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: '<init>', script) + assert outerBytecode.hasSequence([ + 'INVOKEDYNAMIC applyAsInt()Ljava/util/function/IntUnaryOperator;', + 'java/lang/invoke/LambdaMetafactory.metafactory', + 'C$_lambda1.doCall(I)I' + ]) + assert !outerBytecode.hasSequence(['NEW C$_lambda1']) + } + + @Test + void testNonCapturingLambdaInInterfaceDefaultMethodUsesCaptureFreeInvokeDynamic() { + def script = ''' + @CompileStatic + interface Processor { + default IntUnaryOperator process() { + (int i) -> i + 1 + } + } + ''' + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'Processor\\$_process_lambda1', method: 'doCall', script) + assert lambdaBytecode.hasSequence(['public static doCall(I)I']) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'Processor', method: 'process', script) + assert outerBytecode.hasSequence([ + 'INVOKEDYNAMIC applyAsInt()Ljava/util/function/IntUnaryOperator;', + 'java/lang/invoke/LambdaMetafactory.metafactory', + 'Processor$_process_lambda1.doCall(I)I' + ]) + assert !outerBytecode.hasSequence(['NEW Processor$_process_lambda1']) + } + + @Test + void testNonCapturingLambdaWithExceptionInBody() { + assertScript shell, ''' + class C { + static void test() { + Function<String, Integer> f = (String s) -> { + if (s == null) throw new IllegalArgumentException('null input') + return s.length() + } + assert f.apply('hello') == 5 + try { + f.apply(null) + assert false : 'should have thrown' + } catch (IllegalArgumentException e) { + assert e.message == 'null input' + } + } + } + C.test() + ''' + } + + @Test + void testAccessingThisObjectRemainsCapturing() { + assertScript shell, ''' + class C { + String name = 'test' + void test() { + Function<Integer, String> f = (Integer x) -> thisObject.name + x + assert f.apply(1) == 'test1' + } + } + new C().test() + ''' + } + + @Test + void testAccessingThisObjectRetainsInstanceDoCall() { + def bytecode = compileStaticBytecode(classNamePattern: 'C\\$_test_lambda1', method: 'doCall', ''' + @CompileStatic + class C { + String name = 'test' + Function<Integer, String> test() { + (Integer x) -> thisObject.name + x + } + } + ''') + assert bytecode.hasSequence(['public doCall(Ljava/lang/Integer;)Ljava/lang/Object;']) + assert !bytecode.hasSequence(['public static doCall(Ljava/lang/Integer;)Ljava/lang/Object;']) + } + + + private compileStaticBytecode(final Map options = [:], final String script) { + compile(options, COMMON_IMPORTS + script) + } + + private static final String COMMON_IMPORTS = '''\ + import groovy.transform.CompileStatic + import java.io.Serializable + import java.util.function.Consumer + import java.util.function.Function + import java.util.function.IntUnaryOperator + import java.util.function.Supplier + '''.stripIndent() + private static final String SERIALIZED_LAMBDA_GET_CAPTURED_ARG = 'INVOKEVIRTUAL java/lang/invoke/SerializedLambda.getCapturedArg' + } } diff --git a/src/test/groovy/org/codehaus/groovy/classgen/asm/TypeAnnotationsTest.groovy b/src/test/groovy/org/codehaus/groovy/classgen/asm/TypeAnnotationsTest.groovy index c5270ddec5..cdcf626338 100644 --- a/src/test/groovy/org/codehaus/groovy/classgen/asm/TypeAnnotationsTest.groovy +++ b/src/test/groovy/org/codehaus/groovy/classgen/asm/TypeAnnotationsTest.groovy @@ -299,7 +299,7 @@ final class TypeAnnotationsTest extends AbstractBytecodeTestCase { } ''') assert bytecode.hasStrictSequence([ - 'public doCall(I)I', + 'public static doCall(I)I', '@LTypeAnno1;() : METHOD_FORMAL_PARAMETER 0, null', 'L0' ])
