This is an automated email from the ASF dual-hosted git repository. asf-gitbox-commits pushed a commit to branch GROOVY-11993 in repository https://gitbox.apache.org/repos/asf/groovy.git
commit 6b78dfadaa08dbd71f5e8e3f528799e9721b8626 Author: Daniel Sun <[email protected]> AuthorDate: Tue May 5 02:33:46 2026 +0900 GROOVY-11993: Support serializable method reference --- .../asm/sc/AbstractFunctionalInterfaceWriter.java | 348 +++++++++++++--- .../StaticTypesFunctionalInterfaceMetadataKey.java | 52 +++ .../classgen/asm/sc/StaticTypesLambdaAnalyzer.java | 25 +- .../classgen/asm/sc/StaticTypesLambdaWriter.java | 123 +++--- ...StaticTypesMethodReferenceExpressionWriter.java | 442 ++++++++++++++++----- .../groovy/groovy/transform/stc/LambdaTest.groovy | 2 +- .../transform/stc/MethodReferenceTest.groovy | 415 +++++++++++++++++++ 7 files changed, 1150 insertions(+), 257 deletions(-) diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionalInterfaceWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionalInterfaceWriter.java index ceb280c43b..7d437bf566 100644 --- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionalInterfaceWriter.java +++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionalInterfaceWriter.java @@ -22,91 +22,89 @@ import org.codehaus.groovy.ast.ClassHelper; import org.codehaus.groovy.ast.ClassNode; import org.codehaus.groovy.ast.MethodNode; import org.codehaus.groovy.ast.Parameter; +import org.codehaus.groovy.ast.expr.ConstructorCallExpression; +import org.codehaus.groovy.ast.expr.Expression; +import org.codehaus.groovy.ast.expr.MethodCallExpression; +import org.codehaus.groovy.ast.stmt.BlockStatement; +import org.codehaus.groovy.ast.stmt.Statement; +import org.codehaus.groovy.ast.tools.GeneralUtils; +import org.codehaus.groovy.classgen.asm.BytecodeHelper; +import org.codehaus.groovy.classgen.asm.WriterController; import org.codehaus.groovy.syntax.RuntimeParserException; +import org.codehaus.groovy.transform.sc.StaticCompilationMetadataKeys; import org.objectweb.asm.Handle; +import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; import org.objectweb.asm.Type; -import java.util.List; +import java.util.Arrays; +import static org.codehaus.groovy.ast.ClassHelper.OBJECT_TYPE; +import static org.codehaus.groovy.ast.ClassHelper.SERIALIZEDLAMBDA_TYPE; import static org.codehaus.groovy.ast.ClassHelper.getUnwrapper; import static org.codehaus.groovy.ast.ClassHelper.getWrapper; import static org.codehaus.groovy.ast.ClassHelper.isDynamicTyped; import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveType; +import static org.codehaus.groovy.ast.tools.GeneralUtils.args; +import static org.codehaus.groovy.ast.tools.GeneralUtils.boolX; +import static org.codehaus.groovy.ast.tools.GeneralUtils.callX; +import static org.codehaus.groovy.ast.tools.GeneralUtils.classX; +import static org.codehaus.groovy.ast.tools.GeneralUtils.constX; +import static org.codehaus.groovy.ast.tools.GeneralUtils.ctorX; +import static org.codehaus.groovy.ast.tools.GeneralUtils.eqX; +import static org.codehaus.groovy.ast.tools.GeneralUtils.ifS; +import static org.codehaus.groovy.ast.tools.GeneralUtils.param; +import static org.codehaus.groovy.ast.tools.GeneralUtils.params; +import static org.codehaus.groovy.ast.tools.GeneralUtils.returnS; +import static org.codehaus.groovy.ast.tools.GeneralUtils.throwS; +import static org.codehaus.groovy.ast.tools.GeneralUtils.varX; import static org.codehaus.groovy.ast.tools.GenericsUtils.hasUnresolvedGenerics; import static org.codehaus.groovy.classgen.asm.BytecodeHelper.getClassInternalName; import static org.codehaus.groovy.classgen.asm.BytecodeHelper.getMethodDescriptor; +import static org.codehaus.groovy.classgen.asm.sc.StaticTypesFunctionalInterfaceMetadataKey.DESERIALIZE_LAMBDA_DISPATCHER; +import static org.codehaus.groovy.transform.stc.StaticTypesMarker.DIRECT_METHOD_CALL_TARGET; +import static org.objectweb.asm.Opcodes.CHECKCAST; /** - * Represents functional interface writer which contains some common methods to complete generating bytecode + * Shared bytecode and deserialization support for statically-compiled functional interface implementations, + * including both lambdas and method references. + * * @since 3.0.0 */ public interface AbstractFunctionalInterfaceWriter { - default String createMethodDescriptor(final MethodNode method) { - return getMethodDescriptor(method.getReturnType(), method.getParameters()); - } - - default Handle createBootstrapMethod(final boolean isInterface, final boolean serializable) { - return new Handle( - Opcodes.H_INVOKESTATIC, - "java/lang/invoke/LambdaMetafactory", - serializable ? "altMetafactory" : "metafactory", - serializable ? "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;" - : "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;", - false // GROOVY-8299, GROOVY-8989, GROOVY-11265 + default void writeFunctionalInterfaceIndy(final MethodVisitor methodVisitor, + final String samMethodName, final String invokedTypeDescriptor, + final String samMethodDescriptor, final int implMethodKind, + final ClassNode implClassNode, final MethodNode implMethodNode, + final Parameter[] implMethodParameters, final boolean serializable) { + methodVisitor.visitInvokeDynamicInsn( + samMethodName, + invokedTypeDescriptor, + createBootstrapMethod(serializable), + createBootstrapMethodArguments(samMethodDescriptor, implMethodKind, implClassNode, implMethodNode, implMethodParameters, serializable) ); - } - - default Object[] createBootstrapMethodArguments(final String abstractMethodDesc, final int insn, final ClassNode methodOwner, final MethodNode methodNode, final Parameter[] parameters, final boolean serializable) { - ClassNode returnType = methodNode.getReturnType(); - switch (Type.getReturnType(abstractMethodDesc).getSort()) { - case Type.BOOLEAN: - if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Boolean_TYPE; // GROOVY-10975 - break; - case Type.BYTE: - if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Byte_TYPE; - break; - case Type.CHAR: - if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Character_TYPE; - break; - case Type.DOUBLE: - if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Double_TYPE; - break; - case Type.FLOAT: - if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Float_TYPE; - break; - case Type.INT: - if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Integer_TYPE; - break; - case Type.LONG: - if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Long_TYPE; - break; - case Type.SHORT: - if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Short_TYPE; - break; - case Type.VOID: - returnType = ClassHelper.VOID_TYPE; // GROOVY-10933 + if (serializable) { + methodVisitor.visitTypeInsn(CHECKCAST, BytecodeHelper.getClassInternalName(ClassHelper.SERIALIZABLE_TYPE)); } + } - Object[] arguments = !serializable ? new Object[3] : new Object[]{null, null, null, 5, 0}; - - arguments[0] = Type.getMethodType(abstractMethodDesc); - - arguments[1] = new Handle( - insn, // H_INVOKESTATIC or H_INVOKEVIRTUAL or H_INVOKEINTERFACE (GROOVY-9853) - getClassInternalName(methodOwner.getName()), - methodNode.getName(), - getMethodDescriptor(methodNode), - methodOwner.isInterface()); + default String createMethodDescriptor(final MethodNode method) { + return createMethodDescriptor(method.getReturnType(), method.getParameters()); + } - arguments[2] = Type.getMethodType(getMethodDescriptor(returnType, parameters)); + default String createMethodDescriptor(final ClassNode returnType, final Parameter[] parameters) { + return getMethodDescriptor(returnType, parameters); + } - return arguments; + default String createFunctionalInterfaceFactoryDescriptor(final ClassNode functionalType, final Parameter[] capturedParameters) { + return createMethodDescriptor(functionalType.redirect(), capturedParameters); } - default ClassNode convertParameterType(final ClassNode parameterType, final ClassNode inferredType) { - return convertParameterType(parameterType, parameterType, inferredType); + default Parameter createCapturedReceiverParameter(final ClassNode receiverType, final String parameterName) { + Parameter parameter = new Parameter(receiverType, parameterName); + parameter.setClosureSharedVariable(false); + return parameter; } default ClassNode convertParameterType(final ClassNode targetType, final ClassNode parameterType, final ClassNode inferredType) { @@ -128,7 +126,7 @@ public interface AbstractFunctionalInterfaceWriter { // (1) java.lang.invoke.LambdaConversionException: Type mismatch for instantiated parameter 0: class java.lang.Integer is not a subtype of int // (2) java.lang.BootstrapMethodError: bootstrap method initialization exception if (!(isDynamicTyped(parameterType) && isPrimitiveType(targetType)) // (1) - && (parameterType.equals(getUnwrapper(parameterType)) || inferredType.equals(getWrapper(inferredType)))) { // (2) + && (parameterType.equals(getUnwrapper(parameterType)) || inferredType.equals(getWrapper(inferredType)))) { // (2) // The non-primitive type and primitive type are not allowed to mix since Java 9+ // java.lang.invoke.LambdaConversionException: Type mismatch for instantiated parameter 0: int is not a subtype of class java.lang.Object type = getWrapper(inferredType).getPlainNodeReference(); @@ -150,10 +148,230 @@ public interface AbstractFunctionalInterfaceWriter { return type; } - default Parameter prependParameter(final List<Parameter> parameterList, final String parameterName, final ClassNode parameterType) { - Parameter parameter = new Parameter(parameterType, parameterName); - parameter.setClosureSharedVariable(false); - parameterList.add(0, parameter); - return parameter; + default SerializedLambdaFingerprint createSerializedLambdaFingerprint(final String samMethodDescriptor, final ClassNode capturingClass, + final int implMethodKind, final ClassNode implClassNode, + final MethodNode implMethodNode, final Parameter[] implMethodParameters, + final ClassNode functionalType, final MethodNode abstractMethod, + final int capturedArgumentCount) { + return new SerializedLambdaFingerprint( + getClassInternalName(capturingClass), + implMethodKind, + getClassInternalName(implClassNode), + implMethodNode.getName(), + getMethodDescriptor(implMethodNode), + getClassInternalName(functionalType.redirect()), + abstractMethod.getName(), + samMethodDescriptor, + createInstantiatedMethodType(samMethodDescriptor, implMethodNode, implMethodParameters).getDescriptor(), + capturedArgumentCount + ); + } + + default void addDeserializeDispatcherEntry(final WriterController controller, final Parameter[] deserializeMethodParameters, + final SerializedLambdaFingerprint serializedLambdaFingerprint, + final MethodNode helperMethod) { + BlockStatement dispatcherGuards = getOrAddDeserializeDispatcherGuards(controller, deserializeMethodParameters); + MethodCallExpression helperCall = callX(classX(controller.getClassNode()), helperMethod.getName(), args(varX(deserializeMethodParameters[0]))); + helperCall.setImplicitThis(false); + helperCall.setMethodTarget(helperMethod); + + dispatcherGuards.addStatement( + // Keep this guard strict: deserialization must route to exactly one synthetic helper + // whose serialized-lambda fingerprint fully matches the incoming SerializedLambda. + ifS(boolX(matchesSerializedFunctionalInterface(varX(deserializeMethodParameters[0]), serializedLambdaFingerprint)), + returnS(helperCall) + ) + ); + } + + default Parameter[] createDeserializeMethodParameters() { + return new Parameter[] { new Parameter(SERIALIZEDLAMBDA_TYPE, "serializedLambda") }; + } + + private Handle createBootstrapMethod(final boolean serializable) { + return new Handle( + Opcodes.H_INVOKESTATIC, + "java/lang/invoke/LambdaMetafactory", + serializable ? "altMetafactory" : "metafactory", + serializable ? "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;" + : "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;", + false // GROOVY-8299, GROOVY-8989, GROOVY-11265 + ); + } + + private Object[] createBootstrapMethodArguments(final String samMethodDescriptor, final int implMethodKind, + final ClassNode implClassNode, final MethodNode implMethodNode, + final Parameter[] implMethodParameters, final boolean serializable) { + Object[] arguments = !serializable ? new Object[3] : new Object[]{null, null, null, 5, 0}; + + arguments[0] = Type.getMethodType(samMethodDescriptor); + + arguments[1] = new Handle( + implMethodKind, // H_INVOKESTATIC or H_INVOKEVIRTUAL or H_INVOKEINTERFACE (GROOVY-9853) + getClassInternalName(implClassNode.getName()), + implMethodNode.getName(), + getMethodDescriptor(implMethodNode), + implClassNode.isInterface()); + + arguments[2] = createInstantiatedMethodType(samMethodDescriptor, implMethodNode, implMethodParameters); + + return arguments; + } + + private Type createInstantiatedMethodType(final String samMethodDescriptor, final MethodNode implMethodNode, final Parameter[] implMethodParameters) { + ClassNode returnType = implMethodNode.getReturnType(); + switch (Type.getReturnType(samMethodDescriptor).getSort()) { + case Type.BOOLEAN: + if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Boolean_TYPE; // GROOVY-10975 + break; + case Type.BYTE: + if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Byte_TYPE; + break; + case Type.CHAR: + if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Character_TYPE; + break; + case Type.DOUBLE: + if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Double_TYPE; + break; + case Type.FLOAT: + if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Float_TYPE; + break; + case Type.INT: + if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Integer_TYPE; + break; + case Type.LONG: + if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Long_TYPE; + break; + case Type.SHORT: + if (returnType.isGenericsPlaceHolder()) returnType = ClassHelper.Short_TYPE; + break; + case Type.VOID: + returnType = ClassHelper.VOID_TYPE; // GROOVY-10933 + } + + return Type.getMethodType(createMethodDescriptor(returnType, implMethodParameters)); + } + + private BlockStatement getOrAddDeserializeDispatcherGuards(final WriterController controller, final Parameter[] deserializeMethodParameters) { + ClassNode enclosingClass = controller.getClassNode(); + BlockStatement dispatcherGuards = enclosingClass.getNodeMetaData(DESERIALIZE_LAMBDA_DISPATCHER); + if (dispatcherGuards != null) { + return dispatcherGuards; + } + + dispatcherGuards = new BlockStatement(); + BlockStatement dispatcher = new BlockStatement(); + dispatcher.addStatement(dispatcherGuards); + dispatcher.addStatement(createInvalidDeserializationStatement()); + + MethodNode deserializeLambda = enclosingClass.addSyntheticMethod( + "$deserializeLambda$", + Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC, + OBJECT_TYPE, + deserializeMethodParameters, + ClassNode.EMPTY_ARRAY, + dispatcher); + deserializeLambda.putNodeMetaData(StaticCompilationMetadataKeys.STATIC_COMPILE_NODE, Boolean.TRUE); + enclosingClass.putNodeMetaData(DESERIALIZE_LAMBDA_DISPATCHER, dispatcherGuards); + return dispatcherGuards; + } + + private Statement createInvalidDeserializationStatement() { + final ClassNode cn = ClassHelper.make(IllegalArgumentException.class); + ConstructorCallExpression exception = ctorX(cn, constX("Invalid serialized functional interface")); + exception.putNodeMetaData(DIRECT_METHOD_CALL_TARGET, cn.getDeclaredConstructor( + params(param(ClassHelper.STRING_TYPE, "message")) + )); + return throwS(exception); + } + + /** + * Builds the identity check used by {@code $deserializeLambda$} to select the + * correct synthetic helper for a serialized lambda/method reference. + * <p> + * The generated expression is a conjunction over all stable + * {@link java.lang.invoke.SerializedLambda} identity fields we emit via + * {@link #createSerializedLambdaFingerprint(String, ClassNode, int, ClassNode, MethodNode, Parameter[], ClassNode, MethodNode, int)}: + * capturing class, implementation kind/class/name/signature, functional interface class/SAM method/signature, + * instantiated method type, and captured argument count. + * <p> + * Do not weaken this predicate (for example, by checking only method name/class or + * by returning a constant). Doing so can misroute deserialization to the wrong helper, + * while overly strict constant-false behavior breaks valid deserialization. + * + * @param serializedForm the deserialized {@code SerializedLambda} expression + * @param serializedLambdaFingerprint compile-time fingerprint of one serialized lambda target + * @return expression that evaluates to {@code true} only for this exact target + */ + private Expression matchesSerializedFunctionalInterface(final Expression serializedForm, final SerializedLambdaFingerprint serializedLambdaFingerprint) { + return allMatch( + matchesSerializedFormInt(serializedForm, "getCapturedArgCount", serializedLambdaFingerprint.capturedArgCount()), + matchesSerializedFormString(serializedForm, "getCapturingClass", serializedLambdaFingerprint.capturingClass()), + matchesSerializedFormInt(serializedForm, "getImplMethodKind", serializedLambdaFingerprint.implMethodKind()), + matchesSerializedFormString(serializedForm, "getImplClass", serializedLambdaFingerprint.implClass()), + matchesSerializedFormString(serializedForm, "getImplMethodName", serializedLambdaFingerprint.implMethodName()), + matchesSerializedFormString(serializedForm, "getImplMethodSignature", serializedLambdaFingerprint.implMethodSignature()), + matchesSerializedFormString(serializedForm, "getFunctionalInterfaceClass", serializedLambdaFingerprint.functionalInterfaceClass()), + matchesSerializedFormString(serializedForm, "getFunctionalInterfaceMethodName", serializedLambdaFingerprint.functionalInterfaceMethodName()), + matchesSerializedFormString(serializedForm, "getFunctionalInterfaceMethodSignature", serializedLambdaFingerprint.functionalInterfaceMethodSignature()), + matchesSerializedFormString(serializedForm, "getInstantiatedMethodType", serializedLambdaFingerprint.instantiatedMethodType()) + ); + } + + /** + * Combines predicates with logical AND. + * + * @param expressions match predicates to combine + * @return conjunction of all predicates + * @throws IllegalArgumentException if no predicates are supplied + */ + private Expression allMatch(final Expression... expressions) { + return Arrays.stream(expressions) + .reduce(GeneralUtils::andX) + .orElseThrow(() -> new IllegalArgumentException("expressions must not be empty")); + } + + private Expression matchesSerializedFormInt(final Expression serializedForm, final String accessorName, final int expectedValue) { + return eqX(serializedLambdaAccessorCall(serializedForm, accessorName), constX(expectedValue, true)); + } + + private Expression matchesSerializedFormString(final Expression serializedForm, final String accessorName, final String expectedValue) { + return eqX(serializedLambdaAccessorCall(serializedForm, accessorName), constX(expectedValue)); + } + + /** + * Creates a direct {@link java.lang.invoke.SerializedLambda} accessor call for the generated + * deserialization dispatcher. + * <p> + * The dispatcher is emitted during bytecode generation rather than type checking, so it must + * resolve accessor targets explicitly instead of relying on a later static-type-checking pass. + * This keeps the generated bytecode on the direct invocation path and prevents accidental + * fallback to Groovy's dynamic method dispatch for JDK accessors. + * + * @param serializedForm the deserialized {@code SerializedLambda} expression + * @param accessorName the zero-argument accessor to invoke + * @return method call expression with a resolved direct-call target + * @throws IllegalArgumentException if the accessor does not exist on {@code SerializedLambda} + */ + private MethodCallExpression serializedLambdaAccessorCall(final Expression serializedForm, final String accessorName) { + MethodNode accessor = SERIALIZEDLAMBDA_TYPE.getMethod(accessorName, Parameter.EMPTY_ARRAY); + if (accessor == null) { + throw new IllegalArgumentException("Unknown SerializedLambda accessor: " + accessorName); + } + + MethodCallExpression accessorCall = callX(serializedForm, accessorName); + accessorCall.setMethodTarget(accessor); + accessorCall.putNodeMetaData(DIRECT_METHOD_CALL_TARGET, accessor); + return accessorCall; + } + + /** + * Compile-time identity of one serialized functional-interface target used by + * {@code $deserializeLambda$} dispatch. + */ + record SerializedLambdaFingerprint(String capturingClass, int implMethodKind, String implClass, String implMethodName, + String implMethodSignature, String functionalInterfaceClass, + String functionalInterfaceMethodName, String functionalInterfaceMethodSignature, + String instantiatedMethodType, int capturedArgCount) { } } diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesFunctionalInterfaceMetadataKey.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesFunctionalInterfaceMetadataKey.java new file mode 100644 index 0000000000..8deddab27a --- /dev/null +++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesFunctionalInterfaceMetadataKey.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.classgen.asm.sc; + +/** + * Internal AST node metadata keys used by the statically-compiled lambda and method-reference pipeline. + * + * @since 6.0.0 + */ +enum StaticTypesFunctionalInterfaceMetadataKey { + /** + * Stores the shared {@code $deserializeLambda$} guard block on the enclosing class node. + */ + DESERIALIZE_LAMBDA_DISPATCHER, + /** + * Marks the synthetic constructor created for a generated lambda class. + */ + LAMBDA_GENERATED_CONSTRUCTOR, + /** + * Stores the captured shared variables prepared for a lambda expression. + */ + LAMBDA_SHARED_VARIABLES, + /** + * Marks the deserialize helper method that preloads a captured lambda receiver. + */ + LAMBDA_PRELOADED_RECEIVER, + /** + * Caches whether the analyzed lambda method touches enclosing-instance state. + */ + LAMBDA_ACCESSES_INSTANCE_MEMBERS, + /** + * Stores the synthetic deserialize helper name allocated for a method-reference expression so + * repeated bytecode-generation visits can reuse the same helper slot. + */ + METHOD_REFERENCE_DESERIALIZE_METHOD_NAME +} diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaAnalyzer.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaAnalyzer.java index 5509899d86..3561d25937 100644 --- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaAnalyzer.java +++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaAnalyzer.java @@ -40,6 +40,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import static org.codehaus.groovy.classgen.asm.sc.StaticTypesFunctionalInterfaceMetadataKey.LAMBDA_ACCESSES_INSTANCE_MEMBERS; import static org.apache.groovy.util.BeanUtils.capitalize; import static org.codehaus.groovy.ast.tools.GeneralUtils.classX; import static org.codehaus.groovy.transform.stc.StaticTypesMarker.DIRECT_METHOD_CALL_TARGET; @@ -68,14 +69,14 @@ class StaticTypesLambdaAnalyzer { } boolean accessesInstanceMembers(final MethodNode lambdaMethod) { - Boolean accessingInstanceMembers = lambdaMethod.getNodeMetaData(MetaDataKey.ACCESSES_INSTANCE_MEMBERS); + Boolean accessingInstanceMembers = lambdaMethod.getNodeMetaData(LAMBDA_ACCESSES_INSTANCE_MEMBERS); if (accessingInstanceMembers != null) return accessingInstanceMembers; InstanceMemberAccessFinder finder = new InstanceMemberAccessFinder(getOrCreateResolver(lambdaMethod)); lambdaMethod.getCode().visit(finder); accessingInstanceMembers = finder.isAccessingInstanceMembers(); - lambdaMethod.putNodeMetaData(MetaDataKey.ACCESSES_INSTANCE_MEMBERS, accessingInstanceMembers); + lambdaMethod.putNodeMetaData(LAMBDA_ACCESSES_INSTANCE_MEMBERS, accessingInstanceMembers); return accessingInstanceMembers; } @@ -161,9 +162,7 @@ class StaticTypesLambdaAnalyzer { PropertyExpression qualifiedReference = new PropertyExpression(classX(owner), expression.getName()); qualifiedReference.setImplicitThis(false); - qualifiedReference.copyNodeMetaData(expression); - setSourcePosition(qualifiedReference, expression); - return qualifiedReference; + return finishQualifiedReference(qualifiedReference, expression); } private Expression qualify(final AttributeExpression expression) { @@ -177,9 +176,7 @@ class StaticTypesLambdaAnalyzer { ); qualifiedReference.setImplicitThis(false); qualifiedReference.setSpreadSafe(expression.isSpreadSafe()); - qualifiedReference.copyNodeMetaData(expression); - setSourcePosition(qualifiedReference, expression); - return qualifiedReference; + return finishQualifiedReference(qualifiedReference, expression); } private Expression qualify(final PropertyExpression expression) { @@ -193,9 +190,7 @@ class StaticTypesLambdaAnalyzer { ); qualifiedReference.setImplicitThis(false); qualifiedReference.setSpreadSafe(expression.isSpreadSafe()); - qualifiedReference.copyNodeMetaData(expression); - setSourcePosition(qualifiedReference, expression); - return qualifiedReference; + return finishQualifiedReference(qualifiedReference, expression); } private Expression qualify(final MethodCallExpression expression) { @@ -212,6 +207,10 @@ class StaticTypesLambdaAnalyzer { qualifiedReference.setSpreadSafe(expression.isSpreadSafe()); qualifiedReference.setGenericsTypes(expression.getGenericsTypes()); qualifiedReference.setMethodTarget(expression.getMethodTarget()); + return finishQualifiedReference(qualifiedReference, expression); + } + + private <T extends Expression> T finishQualifiedReference(final T qualifiedReference, final Expression expression) { qualifiedReference.copyNodeMetaData(expression); setSourcePosition(qualifiedReference, expression); return qualifiedReference; @@ -394,10 +393,6 @@ class StaticTypesLambdaAnalyzer { } } - private enum MetaDataKey { - ACCESSES_INSTANCE_MEMBERS - } - private final SourceUnit sourceUnit; private final Map<MethodNode, OuterStaticMemberResolver> resolverCache = new HashMap<>(); } diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java index b61dc6c930..3b393b41c9 100644 --- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java +++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java @@ -25,11 +25,8 @@ import org.codehaus.groovy.ast.ConstructorNode; import org.codehaus.groovy.ast.InnerClassNode; import org.codehaus.groovy.ast.MethodNode; import org.codehaus.groovy.ast.Parameter; -import org.codehaus.groovy.ast.builder.AstStringCompiler; import org.codehaus.groovy.ast.expr.ClosureExpression; -import org.codehaus.groovy.ast.expr.Expression; import org.codehaus.groovy.ast.expr.LambdaExpression; -import org.codehaus.groovy.ast.stmt.BlockStatement; import org.codehaus.groovy.ast.stmt.Statement; import org.codehaus.groovy.classgen.BytecodeInstruction; import org.codehaus.groovy.classgen.BytecodeSequence; @@ -50,14 +47,14 @@ import static org.codehaus.groovy.ast.ClassHelper.CLOSURE_TYPE; import static org.codehaus.groovy.ast.ClassHelper.GENERATED_LAMBDA_TYPE; import static org.codehaus.groovy.ast.ClassHelper.OBJECT_TYPE; import static org.codehaus.groovy.ast.ClassHelper.SERIALIZABLE_TYPE; -import static org.codehaus.groovy.ast.ClassHelper.SERIALIZEDLAMBDA_TYPE; import static org.codehaus.groovy.ast.ClassHelper.VOID_TYPE; import static org.codehaus.groovy.ast.tools.ClosureUtils.getParametersSafe; import static org.codehaus.groovy.ast.tools.GeneralUtils.block; import static org.codehaus.groovy.ast.tools.GeneralUtils.classX; -import static org.codehaus.groovy.ast.tools.GeneralUtils.declS; -import static org.codehaus.groovy.ast.tools.GeneralUtils.localVarX; import static org.codehaus.groovy.ast.tools.GeneralUtils.returnS; +import static org.codehaus.groovy.classgen.asm.sc.StaticTypesFunctionalInterfaceMetadataKey.LAMBDA_GENERATED_CONSTRUCTOR; +import static org.codehaus.groovy.classgen.asm.sc.StaticTypesFunctionalInterfaceMetadataKey.LAMBDA_PRELOADED_RECEIVER; +import static org.codehaus.groovy.classgen.asm.sc.StaticTypesFunctionalInterfaceMetadataKey.LAMBDA_SHARED_VARIABLES; import static org.codehaus.groovy.transform.stc.StaticTypesMarker.CLOSURE_ARGUMENTS; import static org.codehaus.groovy.transform.stc.StaticTypesMarker.PARAMETER_TYPE; import static org.objectweb.asm.Opcodes.ACC_FINAL; @@ -99,7 +96,7 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun boolean serializable = makeSerializableIfNeeded(expression, functionalType); GeneratedLambda generatedLambda = getOrAddGeneratedLambda(expression, abstractMethod); - ensureDeserializeLambdaSupport(expression, generatedLambda, serializable); + ensureDeserializeLambdaSupport(expression, functionalType, abstractMethod, generatedLambda, serializable); if (generatedLambda.isCapturing() && !isPreloadedLambdaReceiver(generatedLambda)) { loadLambdaReceiver(generatedLambda); } @@ -107,10 +104,6 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun writeLambdaFactoryInvocation(functionalType.redirect(), abstractMethod, generatedLambda, serializable); } - private static Parameter[] createDeserializeLambdaMethodParams() { - return new Parameter[]{new Parameter(SERIALIZEDLAMBDA_TYPE, "serializedLambda")}; - } - private static MethodNode resolveFunctionalInterfaceMethod(final ClassNode functionalType) { if (functionalType == null || !functionalType.isInterface()) { return null; @@ -125,28 +118,38 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun return expression.isSerializable(); } - private void ensureDeserializeLambdaSupport(final LambdaExpression expression, final GeneratedLambda generatedLambda, final boolean serializable) { + private void ensureDeserializeLambdaSupport(final LambdaExpression expression, final ClassNode functionalType, final MethodNode abstractMethod, final GeneratedLambda generatedLambda, final boolean serializable) { if (!serializable || hasDeserializeLambdaMethod(generatedLambda.lambdaClass)) { return; } - addDeserializeLambdaMethodForLambdaExpression(expression, generatedLambda); - addDeserializeLambdaMethod(); + String samMethodDescriptor = createMethodDescriptor(abstractMethod); + MethodNode helperMethod = addDeserializeLambdaMethodForLambdaExpression(expression, generatedLambda); + addDeserializeDispatcherEntry(controller, createDeserializeMethodParameters(), createSerializedLambdaFingerprint( + samMethodDescriptor, + controller.getClassNode(), + generatedLambda.getImplMethodKind(), + generatedLambda.lambdaClass, + generatedLambda.lambdaMethod, + generatedLambda.lambdaMethod.getParameters(), + functionalType, + abstractMethod, + generatedLambda.isCapturing() ? 1 : 0 + ), helperMethod); } private void writeLambdaFactoryInvocation(final ClassNode functionalType, final MethodNode abstractMethod, final GeneratedLambda generatedLambda, final boolean serializable) { - MethodVisitor mv = controller.getMethodVisitor(); - mv.visitInvokeDynamicInsn( + writeFunctionalInterfaceIndy( + controller.getMethodVisitor(), abstractMethod.getName(), createLambdaFactoryMethodDescriptor(functionalType, generatedLambda), - createBootstrapMethod(controller.getClassNode().isInterface(), serializable), - createBootstrapMethodArguments(createMethodDescriptor(abstractMethod), - generatedLambda.getMethodHandleKind(), - generatedLambda.lambdaClass, generatedLambda.lambdaMethod, generatedLambda.lambdaMethod.getParameters(), serializable) + createMethodDescriptor(abstractMethod), + generatedLambda.getImplMethodKind(), + generatedLambda.lambdaClass, + generatedLambda.lambdaMethod, + generatedLambda.lambdaMethod.getParameters(), + serializable ); - if (serializable) { - mv.visitTypeInsn(CHECKCAST, "java/io/Serializable"); - } if (generatedLambda.nonCapturing()) { controller.getOperandStack().push(functionalType); @@ -156,7 +159,7 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun } private boolean hasDeserializeLambdaMethod(final ClassNode lambdaClass) { - return controller.getClassNode().hasMethod(createDeserializeLambdaMethodName(lambdaClass), createDeserializeLambdaMethodParams()); + return controller.getClassNode().hasMethod(createDeserializeLambdaMethodName(lambdaClass), createDeserializeMethodParameters()); } private static MethodNode getLambdaMethod(final ClassNode lambdaClass) { @@ -169,7 +172,7 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun private static ConstructorNode getGeneratedConstructor(final ClassNode lambdaClass) { for (ConstructorNode constructorNode : lambdaClass.getDeclaredConstructors()) { - if (Boolean.TRUE.equals(constructorNode.getNodeMetaData(MetaDataKey.GENERATED_CONSTRUCTOR))) { + if (Boolean.TRUE.equals(constructorNode.getNodeMetaData(LAMBDA_GENERATED_CONSTRUCTOR))) { return constructorNode; } } @@ -179,7 +182,7 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun private boolean isPreloadedLambdaReceiver(final GeneratedLambda generatedLambda) { MethodNode enclosingMethod = controller.getMethodNode(); return enclosingMethod != null - && enclosingMethod.getNodeMetaData(MetaDataKey.PRELOADED_LAMBDA_RECEIVER) == generatedLambda.lambdaClass; + && enclosingMethod.getNodeMetaData(LAMBDA_PRELOADED_RECEIVER) == generatedLambda.lambdaClass; } private void loadLambdaReceiver(final GeneratedLambda generatedLambda) { @@ -202,7 +205,7 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun loadSharedVariables(generatedLambda.sharedVariables); Parameter[] lambdaClassConstructorParameters = generatedLambda.constructor.getParameters(); - mv.visitMethodInsn(INVOKESPECIAL, lambdaClassInternalName, "<init>", BytecodeHelper.getMethodDescriptor(VOID_TYPE, lambdaClassConstructorParameters), generatedLambda.lambdaClass.isInterface()); + mv.visitMethodInsn(INVOKESPECIAL, lambdaClassInternalName, "<init>", createMethodDescriptor(VOID_TYPE, lambdaClassConstructorParameters), generatedLambda.lambdaClass.isInterface()); operandStack.replace(CLOSURE_TYPE, lambdaClassConstructorParameters.length); } @@ -217,16 +220,10 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun } private String createLambdaFactoryMethodDescriptor(final ClassNode functionalInterface, final GeneratedLambda generatedLambda) { - if (generatedLambda.nonCapturing()) { - return BytecodeHelper.getMethodDescriptor(functionalInterface, Parameter.EMPTY_ARRAY); - } - return BytecodeHelper.getMethodDescriptor(functionalInterface, new Parameter[]{createLambdaReceiverParameter(generatedLambda.lambdaClass)}); - } - - private static Parameter createLambdaReceiverParameter(final ClassNode lambdaClass) { - Parameter parameter = new Parameter(lambdaClass, "__lambda_this"); - parameter.setClosureSharedVariable(false); - return parameter; + return createFunctionalInterfaceFactoryDescriptor(functionalInterface, + generatedLambda.nonCapturing() + ? Parameter.EMPTY_ARRAY + : new Parameter[]{createCapturedReceiverParameter(generatedLambda.lambdaClass, "__lambda_this")}); } private GeneratedLambda getOrAddGeneratedLambda(final LambdaExpression expression, final MethodNode abstractMethod) { @@ -279,7 +276,7 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun addFieldsForLocalVariables(lambdaClass, localVariableParameters); ConstructorNode constructorNode = addConstructor(expression, localVariableParameters, lambdaClass, createBlockStatementForConstructor(expression, outermostClass, enclosingClass)); - constructorNode.putNodeMetaData(MetaDataKey.GENERATED_CONSTRUCTOR, Boolean.TRUE); + constructorNode.putNodeMetaData(LAMBDA_GENERATED_CONSTRUCTOR, Boolean.TRUE); syntheticLambdaMethodNode.getCode().visit(new CorrectAccessedVariableVisitor(lambdaClass)); @@ -297,7 +294,7 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun Parameter[] localVariableParameters = getLambdaSharedVariables(expression); removeInitialValues(localVariableParameters); - expression.putNodeMetaData(MetaDataKey.STORED_LAMBDA_SHARED_VARIABLES, localVariableParameters); + expression.putNodeMetaData(LAMBDA_SHARED_VARIABLES, localVariableParameters); MethodNode doCallMethod = lambdaClass.addMethod( DO_CALL, @@ -327,36 +324,11 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun return lambdaParameters; } - private void addDeserializeLambdaMethod() { - ClassNode enclosingClass = controller.getClassNode(); - Parameter[] parameters = createDeserializeLambdaMethodParams(); - if (enclosingClass.hasMethod("$deserializeLambda$", parameters)) { - return; - } - - Statement code = block( - declS(localVarX("enclosingClass", OBJECT_TYPE), classX(enclosingClass)), - ((BlockStatement) new AstStringCompiler().compile( - "return enclosingClass" + - ".getDeclaredMethod(\"\\$deserializeLambda_${serializedLambda.getImplClass().replace('/', '$')}\\$\", serializedLambda.getClass())" + - ".invoke(null, serializedLambda)" - ).get(0)).getStatements().get(0) - ); - - enclosingClass.addSyntheticMethod( - "$deserializeLambda$", - ACC_PRIVATE | ACC_STATIC, - OBJECT_TYPE, - parameters, - ClassNode.EMPTY_ARRAY, - code); - } - private static boolean requiresLambdaInstance(final MethodNode lambdaMethod) { return 0 == (lambdaMethod.getModifiers() & ACC_STATIC); } - private void addDeserializeLambdaMethodForLambdaExpression(final LambdaExpression expression, final GeneratedLambda generatedLambda) { + private MethodNode addDeserializeLambdaMethodForLambdaExpression(final LambdaExpression expression, final GeneratedLambda generatedLambda) { ClassNode enclosingClass = controller.getClassNode(); Statement code; if (generatedLambda.nonCapturing()) { @@ -385,15 +357,16 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun MethodNode deserializeLambdaMethod = enclosingClass.addSyntheticMethod( createDeserializeLambdaMethodName(generatedLambda.lambdaClass), - ACC_PUBLIC | ACC_STATIC, + ACC_PRIVATE | ACC_STATIC, OBJECT_TYPE, - createDeserializeLambdaMethodParams(), + createDeserializeMethodParameters(), ClassNode.EMPTY_ARRAY, code); if (generatedLambda.isCapturing()) { // The deserialize helper preloads the captured receiver before it reuses the original lambda expression. - deserializeLambdaMethod.putNodeMetaData(MetaDataKey.PRELOADED_LAMBDA_RECEIVER, generatedLambda.lambdaClass); + deserializeLambdaMethod.putNodeMetaData(LAMBDA_PRELOADED_RECEIVER, generatedLambda.lambdaClass); } + return deserializeLambdaMethod; } private static String createDeserializeLambdaMethodName(final ClassNode lambdaClass) { @@ -401,19 +374,17 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun } private static Parameter[] getStoredLambdaSharedVariables(final LambdaExpression expression) { - Parameter[] sharedVariables = expression.getNodeMetaData(MetaDataKey.STORED_LAMBDA_SHARED_VARIABLES); + Parameter[] sharedVariables = expression.getNodeMetaData(LAMBDA_SHARED_VARIABLES); if (sharedVariables == null) { throw new GroovyBugError("Failed to find shared variables for lambda expression"); } return sharedVariables; } - private enum MetaDataKey { - GENERATED_CONSTRUCTOR, - STORED_LAMBDA_SHARED_VARIABLES, - PRELOADED_LAMBDA_RECEIVER - } - + /** + * Cached lambda generation result reused across the emitted indy call and + * any synthetic deserialization helpers for the same expression. + */ private record GeneratedLambda(ClassNode lambdaClass, MethodNode lambdaMethod, ConstructorNode constructor, Parameter[] sharedVariables, boolean nonCapturing, boolean accessingInstanceMembers) { @@ -422,7 +393,7 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun return !nonCapturing; } - private int getMethodHandleKind() { + private int getImplMethodKind() { return nonCapturing ? H_INVOKESTATIC : H_INVOKEVIRTUAL; } } diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java index c526caa7b4..6e404252cf 100644 --- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java +++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java @@ -31,8 +31,9 @@ import org.codehaus.groovy.ast.expr.Expression; import org.codehaus.groovy.ast.expr.MethodCallExpression; import org.codehaus.groovy.ast.expr.MethodReferenceExpression; import org.codehaus.groovy.ast.tools.GeneralUtils; +import org.codehaus.groovy.classgen.BytecodeInstruction; +import org.codehaus.groovy.classgen.BytecodeSequence; import org.codehaus.groovy.classgen.AsmClassGenerator; -import org.codehaus.groovy.classgen.asm.BytecodeHelper; import org.codehaus.groovy.classgen.asm.MethodReferenceExpressionWriter; import org.codehaus.groovy.classgen.asm.WriterController; import org.codehaus.groovy.control.MultipleCompilationErrorsException; @@ -40,7 +41,7 @@ import org.codehaus.groovy.syntax.RuntimeParserException; import org.codehaus.groovy.transform.sc.StaticCompilationMetadataKeys; import org.codehaus.groovy.transform.stc.ExtensionMethodNode; import org.codehaus.groovy.transform.stc.StaticTypesMarker; -import org.objectweb.asm.Handle; +import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; import java.util.Arrays; @@ -60,6 +61,8 @@ import static org.codehaus.groovy.ast.tools.GeneralUtils.returnS; import static org.codehaus.groovy.ast.tools.GeneralUtils.varX; import static org.codehaus.groovy.ast.tools.GenericsUtils.extractPlaceholders; import static org.codehaus.groovy.ast.tools.GenericsUtils.makeClassSafe0; +import static org.codehaus.groovy.classgen.asm.BytecodeHelper.getClassInternalName; +import static org.codehaus.groovy.classgen.asm.sc.StaticTypesFunctionalInterfaceMetadataKey.METHOD_REFERENCE_DESERIALIZE_METHOD_NAME; import static org.codehaus.groovy.ast.tools.ParameterUtils.isVargs; import static org.codehaus.groovy.ast.tools.ParameterUtils.parametersCompatible; import static org.codehaus.groovy.runtime.ArrayGroovyMethods.last; @@ -68,6 +71,13 @@ import static org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport.filter import static org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport.findDGMMethodsForClassNode; import static org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport.isAssignableTo; import static org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport.resolveClassNodeGenerics; +import static org.objectweb.asm.Opcodes.ACC_PRIVATE; +import static org.objectweb.asm.Opcodes.ACC_STATIC; +import static org.objectweb.asm.Opcodes.ALOAD; +import static org.objectweb.asm.Opcodes.ARETURN; +import static org.objectweb.asm.Opcodes.CHECKCAST; +import static org.objectweb.asm.Opcodes.ICONST_0; +import static org.objectweb.asm.Opcodes.INVOKEVIRTUAL; /** * Generates bytecode for method reference expressions in statically-compiled code. @@ -82,132 +92,232 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE @Override public void writeMethodReferenceExpression(final MethodReferenceExpression methodReferenceExpression) { - // functional interface target is required for native method reference generation - ClassNode functionalType = methodReferenceExpression.getNodeMetaData(StaticTypesMarker.PARAMETER_TYPE); - MethodNode abstractMethod = ClassHelper.findSAM(functionalType); - if (abstractMethod == null || !functionalType.isInterface()) { + FunctionalInterfaceContext functionalInterface = resolveFunctionalInterfaceContext(methodReferenceExpression); + if (functionalInterface == null) { // generate the default bytecode -- most likely a method closure super.writeMethodReferenceExpression(methodReferenceExpression); return; } - ClassNode classNode = controller.getClassNode(); + MethodReferenceTarget referenceTarget = resolveMethodReferenceTarget(methodReferenceExpression); + ResolvedMethodReference resolvedMethodReference = resolveMethodReference(methodReferenceExpression, functionalInterface, referenceTarget); + validate(methodReferenceExpression, referenceTarget.type(), resolvedMethodReference.methodName(), + resolvedMethodReference.implementationMethod(), functionalInterface.parametersWithExactType(), + resolveClassNodeGenerics(extractPlaceholders(functionalInterface.functionalType()), null, functionalInterface.abstractMethod().getReturnType())); + + ResolvedMethodReference adaptedMethodReference = adaptMethodReference(functionalInterface, resolvedMethodReference); + ResolvedMethodReference invocationReadyMethodReference = prepareInvocationTarget(methodReferenceExpression, adaptedMethodReference); + MethodReferenceInvocation invocation = createMethodReferenceInvocation(functionalInterface.functionalType(), invocationReadyMethodReference); + + if (functionalInterface.serializable()) { + ensureDeserializeLambdaSupport(methodReferenceExpression, functionalInterface, invocationReadyMethodReference, invocation); + } + + writeFunctionalInterfaceIndy( + controller.getMethodVisitor(), + functionalInterface.abstractMethod().getName(), + invocation.invokedTypeDescriptor(), + functionalInterface.samMethodDescriptor(), + invocation.implMethodKind(), + invocationReadyMethodReference.implementationMethod().getDeclaringClass(), + invocationReadyMethodReference.implementationMethod(), + functionalInterface.parametersWithExactType(), + functionalInterface.serializable() + ); + + updateOperandStack(functionalInterface.functionalType(), invocation.capturing()); + } + + private FunctionalInterfaceContext resolveFunctionalInterfaceContext(final MethodReferenceExpression methodReferenceExpression) { + // functional interface target is required for native method reference generation + ClassNode functionalType = methodReferenceExpression.getNodeMetaData(StaticTypesMarker.PARAMETER_TYPE); + if (functionalType == null || !functionalType.isInterface()) { + return null; + } + + MethodNode abstractMethod = ClassHelper.findSAM(functionalType); + if (abstractMethod == null) { + return null; + } + + ClassNode[] inferredParameterTypes = methodReferenceExpression.getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS); + return new FunctionalInterfaceContext( + functionalType, + abstractMethod, + createParametersWithExactType(abstractMethod, inferredParameterTypes), + createMethodDescriptor(abstractMethod), + functionalType.implementsInterface(ClassHelper.SERIALIZABLE_TYPE) + ); + } + + private MethodReferenceTarget resolveMethodReferenceTarget(final MethodReferenceExpression methodReferenceExpression) { Expression typeOrTargetRef = methodReferenceExpression.getExpression(); - boolean isClassExpression = (typeOrTargetRef instanceof ClassExpression); - boolean targetIsArgument = false; // implied argument for expr::staticMethod? - ClassNode typeOrTargetRefType = isClassExpression ? typeOrTargetRef.getType() - : controller.getTypeChooser().resolveType(typeOrTargetRef, classNode); - - if (ClassHelper.isPrimitiveType(typeOrTargetRefType)) // GROOVY-11353 - typeOrTargetRefType = ClassHelper.getWrapper(typeOrTargetRefType); - - ClassNode[] methodReferenceParamTypes = methodReferenceExpression.getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS); - Parameter[] parametersWithExactType = createParametersWithExactType(abstractMethod, methodReferenceParamTypes); - String methodRefName = methodReferenceExpression.getMethodName().getText(); - boolean isConstructorReference = isConstructorReference(methodRefName); - - MethodNode methodRefMethod; - if (isConstructorReference) { - methodRefName = controller.getContext().getNextConstructorReferenceSyntheticMethodName(controller.getMethodNode()); - methodRefMethod = addSyntheticMethodForConstructorReference(methodRefName, typeOrTargetRefType, parametersWithExactType); - } else { - // TODO: move the findMethodRefMethod and checking to StaticTypeCheckingVisitor - methodRefMethod = findMethodRefMethod(methodRefName, parametersWithExactType, typeOrTargetRef, typeOrTargetRefType); - if (methodReferenceExpression.getNodeMetaData(StaticTypesMarker.PV_METHODS_ACCESS) != null) { // GROOVY-11301, GROOVY-11365: access bridge indicated - Map<MethodNode,MethodNode> bridgeMethods = typeOrTargetRefType.redirect().getNodeMetaData(StaticCompilationMetadataKeys.PRIVATE_BRIDGE_METHODS); - if (bridgeMethods != null) methodRefMethod = bridgeMethods.getOrDefault(methodRefMethod, methodRefMethod); // bridge may not have been generated + boolean classExpression = typeOrTargetRef instanceof ClassExpression; + ClassNode targetType = classExpression + ? typeOrTargetRef.getType() + : controller.getTypeChooser().resolveType(typeOrTargetRef, controller.getClassNode()); + + if (ClassHelper.isPrimitiveType(targetType)) { // GROOVY-11353 + targetType = ClassHelper.getWrapper(targetType); + } + + return new MethodReferenceTarget(typeOrTargetRef, targetType, classExpression, false); + } + + private ResolvedMethodReference resolveMethodReference(final MethodReferenceExpression methodReferenceExpression, + final FunctionalInterfaceContext functionalInterface, + final MethodReferenceTarget referenceTarget) { + String methodName = methodReferenceExpression.getMethodName().getText(); + if (isConstructorReference(methodName)) { + String syntheticMethodName = controller.getContext().getNextConstructorReferenceSyntheticMethodName(controller.getMethodNode()); + MethodNode constructorReferenceMethod = addSyntheticMethodForConstructorReference( + syntheticMethodName, + referenceTarget.type(), + functionalInterface.parametersWithExactType() + ); + return new ResolvedMethodReference(referenceTarget, syntheticMethodName, constructorReferenceMethod, true); + } + + return new ResolvedMethodReference( + referenceTarget, + methodName, + findMethodReferenceImplementation(methodReferenceExpression, methodName, functionalInterface.parametersWithExactType(), referenceTarget), + false + ); + } + + private MethodNode findMethodReferenceImplementation(final MethodReferenceExpression methodReferenceExpression, final String methodName, + final Parameter[] samParameters, final MethodReferenceTarget referenceTarget) { + // TODO: move the method lookup and validation to StaticTypeCheckingVisitor + MethodNode methodRefMethod = findMethodRefMethod(methodName, samParameters, referenceTarget.expression(), referenceTarget.type()); + if (methodReferenceExpression.getNodeMetaData(StaticTypesMarker.PV_METHODS_ACCESS) != null) { // GROOVY-11301, GROOVY-11365: access bridge indicated + Map<MethodNode, MethodNode> bridgeMethods = referenceTarget.type().redirect().getNodeMetaData(StaticCompilationMetadataKeys.PRIVATE_BRIDGE_METHODS); + if (bridgeMethods != null) { + methodRefMethod = bridgeMethods.getOrDefault(methodRefMethod, methodRefMethod); // bridge may not have been generated } - if (methodRefMethod == null && isClassExpression) { - var classValue = varX("_class_", typeOrTargetRefType); - var classClass = makeClassSafe0(ClassHelper.CLASS_Type, new GenericsType(typeOrTargetRefType)); - methodRefMethod = findMethodRefMethod(methodRefName, parametersWithExactType, classValue, classClass); - if (methodRefMethod != null) methodRefMethod = addSyntheticMethodForClassReference(methodRefMethod, typeOrTargetRefType); + } + if (methodRefMethod == null && referenceTarget.classExpression()) { + Expression classValue = varX("_class_", referenceTarget.type()); + ClassNode classClass = makeClassSafe0(ClassHelper.CLASS_Type, new GenericsType(referenceTarget.type())); + methodRefMethod = findMethodRefMethod(methodName, samParameters, classValue, classClass); + if (methodRefMethod != null) { + methodRefMethod = addSyntheticMethodForClassReference(methodRefMethod, referenceTarget.type()); } } + return methodRefMethod; + } - validate(methodReferenceExpression, typeOrTargetRefType, methodRefName, methodRefMethod, parametersWithExactType, - resolveClassNodeGenerics(extractPlaceholders(functionalType), null, abstractMethod.getReturnType())); + private ResolvedMethodReference adaptMethodReference(final FunctionalInterfaceContext functionalInterface, + final ResolvedMethodReference resolvedMethodReference) { + MethodReferenceTarget referenceTarget = resolvedMethodReference.referenceTarget(); + MethodNode methodRefMethod = resolvedMethodReference.implementationMethod(); if (isBridgeMethod(methodRefMethod)) { - targetIsArgument = true; // GROOVY-11301, GROOVY-11365 - if (isClassExpression) { // method expects an instance argument + referenceTarget = referenceTarget.markTargetAsArgument(); // GROOVY-11301, GROOVY-11365 + if (referenceTarget.classExpression()) { // method expects an instance argument methodRefMethod = addSyntheticMethodForDGSM(methodRefMethod); } - } else if (isExtensionMethod(methodRefMethod)) { + return resolvedMethodReference.with(referenceTarget, methodRefMethod); + } + + if (isExtensionMethod(methodRefMethod)) { ExtensionMethodNode extensionMethodNode = (ExtensionMethodNode) methodRefMethod; - methodRefMethod = extensionMethodNode.getExtensionMethodNode(); - boolean isStatic = extensionMethodNode.isStaticExtension(); - if (isStatic) { // create adapter method to pass extra argument + methodRefMethod = extensionMethodNode.getExtensionMethodNode(); + boolean staticExtension = extensionMethodNode.isStaticExtension(); + if (staticExtension) { // create adapter method to pass extra argument methodRefMethod = addSyntheticMethodForDGSM(methodRefMethod); } - if (isStatic || isClassExpression) { - // replace expression with declaring type - typeOrTargetRefType = methodRefMethod.getDeclaringClass(); - typeOrTargetRef = makeClassTarget(typeOrTargetRefType, typeOrTargetRef); + if (staticExtension || referenceTarget.classExpression()) { + referenceTarget = referenceTarget.asClassTarget(methodRefMethod.getDeclaringClass()); } else { // GROOVY-10653 - targetIsArgument = true; // ex: "string"::size - } - } else if (isVargs(methodRefMethod.getParameters())) { - int mParameters = abstractMethod.getParameters().length; - int nParameters = methodRefMethod.getParameters().length; - if (isTypeReferringInstanceMethod(typeOrTargetRef, methodRefMethod)) nParameters += 1; - if (mParameters > nParameters || mParameters == nParameters-1 || (mParameters == nParameters - && !isAssignableTo(last(parametersWithExactType).getType(), last(methodRefMethod.getParameters()).getType()))) { - // GROOVY-9813: reference to variadic method which needs adapter method to match runtime arguments to its parameters - if (!isClassExpression && !methodRefMethod.isStatic() && !methodRefMethod.getDeclaringClass().equals(classNode)) { - targetIsArgument = true; // GROOVY-10653: create static adapter in source class with target as first parameter - mParameters += 1; - } - methodRefMethod = addSyntheticMethodForVariadicReference(methodRefMethod, mParameters, isClassExpression || targetIsArgument); - if (methodRefMethod.isStatic() && !targetIsArgument) { - // replace expression with declaring type - typeOrTargetRefType = methodRefMethod.getDeclaringClass(); - typeOrTargetRef = makeClassTarget(typeOrTargetRefType, typeOrTargetRef); - } + referenceTarget = referenceTarget.markTargetAsArgument(); // ex: "string"::size } + return resolvedMethodReference.with(referenceTarget, methodRefMethod); } - if (!isClassExpression) { - if (isConstructorReference) { // TODO: move this check to the parser - addFatalError("Constructor reference must be TypeName::new", methodReferenceExpression); - } else if (methodRefMethod.isStatic() && !targetIsArgument) { - // "string"::valueOf refers to static method, so instance is superfluous - typeOrTargetRef = makeClassTarget(typeOrTargetRefType, typeOrTargetRef); - isClassExpression = true; - } else { - typeOrTargetRef.visit(controller.getAcg()); - controller.getOperandStack().box(); // GROOVY-11353 + if (needsVariadicAdapter(functionalInterface, referenceTarget, methodRefMethod)) { + int samParameterCount = functionalInterface.abstractMethod().getParameters().length; + if (!referenceTarget.classExpression() && !methodRefMethod.isStatic() && !methodRefMethod.getDeclaringClass().equals(controller.getClassNode())) { + referenceTarget = referenceTarget.markTargetAsArgument(); // GROOVY-10653: create static adapter in source class with target as first parameter + samParameterCount += 1; + } + methodRefMethod = addSyntheticMethodForVariadicReference(methodRefMethod, samParameterCount, + referenceTarget.classExpression() || referenceTarget.targetIsArgument()); + if (methodRefMethod.isStatic() && !referenceTarget.targetIsArgument()) { + referenceTarget = referenceTarget.asClassTarget(methodRefMethod.getDeclaringClass()); } } - int referenceKind; - if (isConstructorReference || methodRefMethod.isStatic()) { - referenceKind = Opcodes.H_INVOKESTATIC; + return resolvedMethodReference.with(referenceTarget, methodRefMethod); + } + + private boolean needsVariadicAdapter(final FunctionalInterfaceContext functionalInterface, + final MethodReferenceTarget referenceTarget, + final MethodNode methodRefMethod) { + if (!isVargs(methodRefMethod.getParameters())) { + return false; + } + + int samParameterCount = functionalInterface.abstractMethod().getParameters().length; + int methodParameterCount = methodRefMethod.getParameters().length; + if (isTypeReferringInstanceMethod(referenceTarget.expression(), methodRefMethod)) { + methodParameterCount += 1; + } + + return samParameterCount > methodParameterCount + || samParameterCount == methodParameterCount - 1 + || (samParameterCount == methodParameterCount + && !isAssignableTo(last(functionalInterface.parametersWithExactType()).getType(), last(methodRefMethod.getParameters()).getType())); + } + + private ResolvedMethodReference prepareInvocationTarget(final MethodReferenceExpression methodReferenceExpression, + final ResolvedMethodReference resolvedMethodReference) { + MethodReferenceTarget referenceTarget = resolvedMethodReference.referenceTarget(); + if (referenceTarget.classExpression()) { + return resolvedMethodReference; + } + + if (resolvedMethodReference.constructorReference()) { // TODO: move this check to the parser + addFatalError("Constructor reference must be TypeName::new", methodReferenceExpression); + } else if (resolvedMethodReference.implementationMethod().isStatic() && !referenceTarget.targetIsArgument()) { + // "string"::valueOf refers to static method, so the bound instance is superfluous. + return resolvedMethodReference.withTarget(referenceTarget.asClassTarget(referenceTarget.type())); + } else { + referenceTarget.expression().visit(controller.getAcg()); + controller.getOperandStack().box(); // GROOVY-11353 + } + + return resolvedMethodReference; + } + + private MethodReferenceInvocation createMethodReferenceInvocation(final ClassNode functionalType, + final ResolvedMethodReference resolvedMethodReference) { + MethodNode methodRefMethod = resolvedMethodReference.implementationMethod(); + int implMethodKind; + if (resolvedMethodReference.constructorReference() || methodRefMethod.isStatic()) { + implMethodKind = Opcodes.H_INVOKESTATIC; } else if (methodRefMethod.getDeclaringClass().isInterface()) { - referenceKind = Opcodes.H_INVOKEINTERFACE; // GROOVY-9853 + implMethodKind = Opcodes.H_INVOKEINTERFACE; // GROOVY-9853 } else { - referenceKind = Opcodes.H_INVOKEVIRTUAL; + implMethodKind = Opcodes.H_INVOKEVIRTUAL; } - String methodName = abstractMethod.getName(); - String methodDesc = BytecodeHelper.getMethodDescriptor(functionalType.redirect(), - isClassExpression ? Parameter.EMPTY_ARRAY : new Parameter[]{new Parameter(typeOrTargetRefType, "__METHODREF_EXPR_INSTANCE")}); - - Handle bootstrapMethod = createBootstrapMethod(classNode.isInterface(), false); - Object[] bootstrapArgs = createBootstrapMethodArguments( - createMethodDescriptor(abstractMethod), - referenceKind, - methodRefMethod.getDeclaringClass(), - methodRefMethod, - parametersWithExactType, - false + MethodReferenceTarget referenceTarget = resolvedMethodReference.referenceTarget(); + Parameter[] capturedParameters = referenceTarget.classExpression() + ? Parameter.EMPTY_ARRAY + : new Parameter[]{createCapturedReceiverParameter(referenceTarget.type(), "__METHODREF_EXPR_INSTANCE")}; + return new MethodReferenceInvocation( + implMethodKind, + createFunctionalInterfaceFactoryDescriptor(functionalType, capturedParameters), + referenceTarget.isCapturing() ); - controller.getMethodVisitor().visitInvokeDynamicInsn(methodName, methodDesc, bootstrapMethod, bootstrapArgs); + } - if (isClassExpression) { - controller.getOperandStack().push(functionalType); - } else { + private void updateOperandStack(final ClassNode functionalType, final boolean capturing) { + if (capturing) { controller.getOperandStack().replace(functionalType, 1); + } else { + controller.getOperandStack().push(functionalType); } } @@ -255,7 +365,7 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE methodCall.setMethodTarget(mn); methodCall.putNodeMetaData(StaticTypesMarker.DIRECT_METHOD_CALL_TARGET, mn); - String methodName = "class$" + classType.getNameWithoutPackage() + "$" + mn.getName() + "$" + System.nanoTime(); + String methodName = createSyntheticMethodName("class", classType, mn.getName()); ClassNode returnType = resolveClassNodeGenerics(Map.of(new GenericsType.GenericsTypeName("T"), new GenericsType(classType)), null, mn.getReturnType()); @@ -297,7 +407,7 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE methodCall.setMethodTarget(mn); methodCall.putNodeMetaData(StaticTypesMarker.DIRECT_METHOD_CALL_TARGET, mn); - String methodName = "adapt$" + mn.getDeclaringClass().getNameWithoutPackage() + "$" + mn.getName() + "$" + System.nanoTime(); + String methodName = createSyntheticMethodName("adapt", mn.getDeclaringClass(), mn.getName()); MethodNode delegateMethod = addSyntheticMethod(methodName, mn.getReturnType(), methodCall, parameters, mn.getExceptions()); if (!isStaticTarget && !mn.isStatic()) delegateMethod.setModifiers(delegateMethod.getModifiers() & ~Opcodes.ACC_STATIC); @@ -346,7 +456,7 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE if (inferredParamType == null) continue; Parameter parameter = parameters[i]; - ClassNode type = convertParameterType(parameter.getType(), inferredParamType); + ClassNode type = convertParameterType(parameter.getType(), parameter.getType(), inferredParamType); parameter.setOriginType(type); parameter.setType(type); } @@ -454,6 +564,90 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE throw new MultipleCompilationErrorsException(controller.getSourceUnit().getErrorCollector()); } + private void ensureDeserializeLambdaSupport(final MethodReferenceExpression methodReferenceExpression, + final FunctionalInterfaceContext functionalInterface, + final ResolvedMethodReference resolvedMethodReference, + final MethodReferenceInvocation invocation) { + String helperName = getOrCreateDeserializeLambdaMethodName(methodReferenceExpression); + Parameter[] parameters = createDeserializeMethodParameters(); + if (controller.getClassNode().hasMethod(helperName, parameters)) { + return; + } + + MethodReferenceTarget referenceTarget = resolvedMethodReference.referenceTarget(); + MethodNode methodRefMethod = resolvedMethodReference.implementationMethod(); + MethodNode helperMethod = addDeserializeLambdaMethodForMethodReference( + helperName, + functionalInterface.abstractMethod(), + methodRefMethod, + functionalInterface.parametersWithExactType(), + invocation.implMethodKind(), + invocation.capturing(), + referenceTarget.type(), + invocation.invokedTypeDescriptor(), + functionalInterface.samMethodDescriptor() + ); + addDeserializeDispatcherEntry(controller, parameters, + createSerializedLambdaFingerprint(functionalInterface.samMethodDescriptor(), controller.getClassNode(), invocation.implMethodKind(), + methodRefMethod.getDeclaringClass(), methodRefMethod, + functionalInterface.parametersWithExactType(), functionalInterface.functionalType(), + functionalInterface.abstractMethod(), invocation.capturing() ? 1 : 0), + helperMethod); + } + + private String getOrCreateDeserializeLambdaMethodName(final MethodReferenceExpression methodReferenceExpression) { + String helperName = methodReferenceExpression.getNodeMetaData(METHOD_REFERENCE_DESERIALIZE_METHOD_NAME); + if (helperName == null) { + helperName = createDeserializeLambdaMethodName(); + methodReferenceExpression.putNodeMetaData(METHOD_REFERENCE_DESERIALIZE_METHOD_NAME, helperName); + } + return helperName; + } + + private MethodNode addDeserializeLambdaMethodForMethodReference(final String methodName, final MethodNode abstractMethod, + final MethodNode methodRefMethod, final Parameter[] parametersWithExactType, + final int implMethodKind, final boolean capturing, + final ClassNode capturedTargetType, final String invokedTypeDescriptor, + final String samMethodDescriptor) { + return controller.getClassNode().addSyntheticMethod( + methodName, + ACC_PRIVATE | ACC_STATIC, + ClassHelper.OBJECT_TYPE, + createDeserializeMethodParameters(), + ClassNode.EMPTY_ARRAY, + new BytecodeSequence(new BytecodeInstruction() { + @Override + public void visit(final MethodVisitor mv) { + if (capturing) { + mv.visitVarInsn(ALOAD, 0); + mv.visitInsn(ICONST_0); + mv.visitMethodInsn(INVOKEVIRTUAL, "java/lang/invoke/SerializedLambda", "getCapturedArg", "(I)Ljava/lang/Object;", false); + mv.visitTypeInsn(CHECKCAST, getClassInternalName(capturedTargetType)); + } + writeFunctionalInterfaceIndy( + mv, + abstractMethod.getName(), + invokedTypeDescriptor, + samMethodDescriptor, + implMethodKind, + methodRefMethod.getDeclaringClass(), + methodRefMethod, + parametersWithExactType, + true + ); + mv.visitInsn(ARETURN); + } + })); + } + + private String createSyntheticMethodName(final String prefix, final ClassNode owner, final String name) { + return prefix + "$" + owner.getNameWithoutPackage() + "$" + name + "$" + controller.getNextHelperMethodIndex(); + } + + private String createDeserializeLambdaMethodName() { + return "$deserializeLambda_methodref$" + controller.getNextHelperMethodIndex() + "$"; + } + //-------------------------------------------------------------------------- private static boolean isBridgeMethod(final MethodNode mn) { @@ -487,4 +681,52 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE private static Parameter[] removeFirstParameter(final Parameter[] parameters) { return Arrays.copyOfRange(parameters, 1, parameters.length); } + + /** + * Captures the functional-interface side of a method reference after type + * inference has fixed the SAM signature. + */ + private record FunctionalInterfaceContext(ClassNode functionalType, MethodNode abstractMethod, + Parameter[] parametersWithExactType, String samMethodDescriptor, + boolean serializable) { + } + + /** + * Models the source-side target of a method reference, including whether + * the target must still be captured at runtime. + */ + private record MethodReferenceTarget(Expression expression, ClassNode type, boolean classExpression, boolean targetIsArgument) { + private MethodReferenceTarget markTargetAsArgument() { + return new MethodReferenceTarget(expression, type, classExpression, true); + } + + private MethodReferenceTarget asClassTarget(final ClassNode targetType) { + return new MethodReferenceTarget(makeClassTarget(targetType, expression), targetType, true, targetIsArgument); + } + + private boolean isCapturing() { + return !classExpression; + } + } + + /** + * Selected implementation method together with the possibly rewritten + * method-reference target used to invoke it. + */ + private record ResolvedMethodReference(MethodReferenceTarget referenceTarget, String methodName, + MethodNode implementationMethod, boolean constructorReference) { + private ResolvedMethodReference with(final MethodReferenceTarget updatedTarget, final MethodNode updatedMethod) { + return new ResolvedMethodReference(updatedTarget, methodName, updatedMethod, constructorReference); + } + + private ResolvedMethodReference withTarget(final MethodReferenceTarget updatedTarget) { + return new ResolvedMethodReference(updatedTarget, methodName, implementationMethod, constructorReference); + } + } + + /** + * Bytecode-level invocation details derived from the resolved method reference. + */ + private record MethodReferenceInvocation(int implMethodKind, String invokedTypeDescriptor, boolean capturing) { + } } diff --git a/src/test/groovy/groovy/transform/stc/LambdaTest.groovy b/src/test/groovy/groovy/transform/stc/LambdaTest.groovy index 37eab1575c..d449f7f3f4 100644 --- a/src/test/groovy/groovy/transform/stc/LambdaTest.groovy +++ b/src/test/groovy/groovy/transform/stc/LambdaTest.groovy @@ -3423,6 +3423,6 @@ final class LambdaTest { import java.util.function.IntUnaryOperator import java.util.function.Supplier '''.stripIndent() - private static final String SERIALIZED_LAMBDA_GET_CAPTURED_ARG = 'INVOKEVIRTUAL java/lang/invoke/SerializedLambda.getCapturedArg' + private static final String SERIALIZED_LAMBDA_GET_CAPTURED_ARG = 'INVOKEVIRTUAL java/lang/invoke/SerializedLambda.getCapturedArg (I)Ljava/lang/Object;' } } diff --git a/src/test/groovy/groovy/transform/stc/MethodReferenceTest.groovy b/src/test/groovy/groovy/transform/stc/MethodReferenceTest.groovy index 3c6b5e1b90..6b60ac7bf4 100644 --- a/src/test/groovy/groovy/transform/stc/MethodReferenceTest.groovy +++ b/src/test/groovy/groovy/transform/stc/MethodReferenceTest.groovy @@ -1158,6 +1158,144 @@ final class MethodReferenceTest { ''' } + @Test // class::new + void testSerializableConstructorReference() { + assertScript shell, ''' + import java.io.ByteArrayInputStream + import java.io.ByteArrayOutputStream + import java.io.Serializable + + @CompileStatic + class C { + interface SerFunc<I, O> extends Serializable, Function<I, O> {} + + static class Box { + final String value + + Box(String value) { + this.value = value.trim() + } + } + + static SerFunc<String, Box> create() { + Box::new + } + + static byte[] serialize(Serializable value) { + def out = new ByteArrayOutputStream() + out.withObjectOutputStream { it.writeObject(value) } + out.toByteArray() + } + + static <T> T deserialize(byte[] bytes) { + new ByteArrayInputStream(bytes).withObjectInputStream(C.classLoader) { + (T) it.readObject() + } + } + } + + assert C.declaredMethods.count { it.name == '$deserializeLambda$' } == 1 + + C.SerFunc<String, C.Box> factory = C.deserialize(C.serialize(C.create())) + assert factory instanceof Serializable + assert factory.apply(' ok ').value == 'ok' + ''' + } + + @Test // arrayClass::new + void testSerializableArrayConstructorReference() { + assertScript shell, ''' + import java.io.ByteArrayInputStream + import java.io.ByteArrayOutputStream + import java.io.Serializable + + @CompileStatic + class C { + interface SerIntFunc<T> extends Serializable, IntFunction<T> {} + + static SerIntFunc<String[]> create() { + String[]::new + } + + static byte[] serialize(Serializable value) { + def out = new ByteArrayOutputStream() + out.withObjectOutputStream { it.writeObject(value) } + out.toByteArray() + } + + static <T> T deserialize(byte[] bytes) { + new ByteArrayInputStream(bytes).withObjectInputStream(C.classLoader) { + (T) it.readObject() + } + } + } + + C.SerIntFunc<String[]> factory = C.deserialize(C.serialize(C.create())) + String[] values = factory.apply(3) + assert values.length == 3 + assert values.toList() == [null, null, null] + ''' + } + + @Test + void testSerializableConstructorReferencesShareDeserializeDispatcherWithLambdas() { + assertScript shell, ''' + import java.io.ByteArrayInputStream + import java.io.ByteArrayOutputStream + import java.io.Serializable + + @CompileStatic + class C { + interface SerFunc<I, O> extends Serializable, Function<I, O> {} + interface SerIntFunc<T> extends Serializable, IntFunction<T> {} + + static class Box { + final String value + + Box(String value) { + this.value = value + } + } + + static SerFunc<String, Box> createConstructorReference() { + Box::new + } + + static SerIntFunc<Box[]> createArrayConstructorReference() { + Box[]::new + } + + static SerFunc<Integer, String> createLambda() { + (Integer i) -> 'L' + i + } + + static byte[] serialize(Serializable value) { + def out = new ByteArrayOutputStream() + out.withObjectOutputStream { it.writeObject(value) } + out.toByteArray() + } + + static <T> T deserialize(byte[] bytes) { + new ByteArrayInputStream(bytes).withObjectInputStream(C.classLoader) { + (T) it.readObject() + } + } + } + + assert C.declaredMethods.count { it.name == '$deserializeLambda$' } == 1 + assert C.declaredMethods.findAll { it.name.startsWith('$deserializeLambda') && it.name != '$deserializeLambda$' } + .every { java.lang.reflect.Modifier.isPrivate(it.modifiers) && java.lang.reflect.Modifier.isStatic(it.modifiers) } + + C.SerFunc<String, C.Box> ctor = C.deserialize(C.serialize(C.createConstructorReference())) + C.SerIntFunc<C.Box[]> arrayCtor = C.deserialize(C.serialize(C.createArrayConstructorReference())) + C.SerFunc<Integer, String> lambda = C.deserialize(C.serialize(C.createLambda())) + + assert ctor.apply('box').value == 'box' + assert arrayCtor.apply(2).length == 2 + assert lambda.apply(4) == 'L4' + ''' + } + @Test // class::staticMethod void testFunctionCS() { assertScript shell, ''' @@ -1662,6 +1800,283 @@ final class MethodReferenceTest { ''' } + @Test + void testSerializableNonCapturingMethodReference() { + assertScript shell, ''' + import java.io.ByteArrayInputStream + import java.io.ByteArrayOutputStream + import java.io.Serializable + + @CompileStatic + class C { + interface SerFunc<I, O> extends Serializable, Function<I, O> {} + + static SerFunc<Integer, String> create() { + Integer::toString + } + + static byte[] serialize(Serializable value) { + def out = new ByteArrayOutputStream() + out.withObjectOutputStream { it.writeObject(value) } + out.toByteArray() + } + + static <T> T deserialize(byte[] bytes) { + new ByteArrayInputStream(bytes).withObjectInputStream(C.classLoader) { + (T) it.readObject() + } + } + } + + assert C.declaredMethods.count { it.name == '$deserializeLambda$' } == 1 + + C.SerFunc<Integer, String> fn = C.deserialize(C.serialize(C.create())) + assert fn instanceof Serializable + assert fn.apply(7) == '7' + ''' + } + + @Test + void testSerializableCapturingMethodReference() { + assertScript shell, ''' + import java.io.ByteArrayInputStream + import java.io.ByteArrayOutputStream + import java.io.Serializable + + @CompileStatic + class C { + interface SerSupplier<T> extends Serializable, Supplier<T> {} + + private final String text + + C(String text) { + this.text = text + } + + SerSupplier<String> create() { + text::trim + } + + static byte[] serialize(Serializable value) { + def out = new ByteArrayOutputStream() + out.withObjectOutputStream { it.writeObject(value) } + out.toByteArray() + } + + static <T> T deserialize(byte[] bytes) { + new ByteArrayInputStream(bytes).withObjectInputStream(C.classLoader) { + (T) it.readObject() + } + } + } + + C.SerSupplier<String> supplier = C.deserialize(C.serialize(new C(' answer ').create())) + assert supplier instanceof Serializable + assert supplier.get() == 'answer' + ''' + } + + @Test + void testSerializableMethodReferencesShareDeserializeDispatcherWithLambdas() { + assertScript shell, ''' + import java.io.ByteArrayInputStream + import java.io.ByteArrayOutputStream + import java.io.Serializable + + @CompileStatic + class C { + interface SerFunc<I, O> extends Serializable, Function<I, O> {} + interface SerSupplier<T> extends Serializable, Supplier<T> {} + + private final String text + + C(String text) { + this.text = text + } + + static SerFunc<Integer, String> createMethodReference() { + Integer::toString + } + + static SerFunc<Integer, String> createLambda() { + (Integer i) -> 'L' + i + } + + SerSupplier<String> createBoundMethodReference() { + text::trim + } + + static byte[] serialize(Serializable value) { + def out = new ByteArrayOutputStream() + out.withObjectOutputStream { it.writeObject(value) } + out.toByteArray() + } + + static <T> T deserialize(byte[] bytes) { + new ByteArrayInputStream(bytes).withObjectInputStream(C.classLoader) { + (T) it.readObject() + } + } + } + + assert C.declaredMethods.count { it.name == '$deserializeLambda$' } == 1 + assert C.declaredMethods.findAll { it.name.startsWith('$deserializeLambda') && it.name != '$deserializeLambda$' } + .every { java.lang.reflect.Modifier.isPrivate(it.modifiers) && java.lang.reflect.Modifier.isStatic(it.modifiers) } + + C.SerFunc<Integer, String> methodRef = C.deserialize(C.serialize(C.createMethodReference())) + C.SerFunc<Integer, String> lambda = C.deserialize(C.serialize(C.createLambda())) + C.SerSupplier<String> bound = C.deserialize(C.serialize(new C(' x ').createBoundMethodReference())) + + assert methodRef.apply(3) == '3' + assert lambda.apply(4) == 'L4' + assert bound.get() == 'x' + ''' + } + + @Test + void testDeserializeDispatcherReturnsMatchingMethodReferenceAndLambdaBeforeFallback() { + assertScript shell, ''' + import java.io.Serializable + import java.lang.invoke.SerializedLambda + + @CompileStatic + class C { + interface SerFunc<I, O> extends Serializable, Function<I, O> {} + + static SerFunc<Integer, String> createMethodReference() { + Integer::toString + } + + static SerFunc<Integer, String> createLambda() { + (Integer i) -> 'L' + i + } + + @CompileDynamic + static SerializedLambda serialized(Serializable value) { + def writeReplace = value.class.getDeclaredMethod('writeReplace') + writeReplace.accessible = true + (SerializedLambda) writeReplace.invoke(value) + } + } + + def dispatcher = C.getDeclaredMethod('$deserializeLambda$', SerializedLambda) + dispatcher.accessible = true + + C.SerFunc<Integer, String> methodRef = + (C.SerFunc<Integer, String>) dispatcher.invoke(null, C.serialized(C.createMethodReference())) + C.SerFunc<Integer, String> lambda = + (C.SerFunc<Integer, String>) dispatcher.invoke(null, C.serialized(C.createLambda())) + + assert methodRef.apply(3) == '3' + assert lambda.apply(3) == 'L3' + ''' + } + + @Test + void testDeserializeDispatcherRejectsWrongCapturingClassEvenWhenOtherSerializedFieldsMatch() { + assertScript shell, ''' + import java.io.Serializable + import java.lang.invoke.SerializedLambda + + @CompileStatic + class C { + interface SerFunc<I, O> extends Serializable, Function<I, O> {} + + static SerFunc<Integer, String> createMethodReference() { + Integer::toString + } + + static SerFunc<Integer, String> createLambda() { + (Integer i) -> 'L' + i + } + + @CompileDynamic + static SerializedLambda serialized(Serializable value) { + def writeReplace = value.class.getDeclaredMethod('writeReplace') + writeReplace.accessible = true + (SerializedLambda) writeReplace.invoke(value) + } + + @CompileDynamic + static SerializedLambda withCapturingClass(SerializedLambda serialized, Class capturingClass) { + new SerializedLambda( + capturingClass, + serialized.functionalInterfaceClass, + serialized.functionalInterfaceMethodName, + serialized.functionalInterfaceMethodSignature, + serialized.implMethodKind, + serialized.implClass, + serialized.implMethodName, + serialized.implMethodSignature, + serialized.instantiatedMethodType, + (0..<serialized.capturedArgCount).collect { serialized.getCapturedArg(it) } as Object[] + ) + } + } + + def dispatcher = C.getDeclaredMethod('$deserializeLambda$', SerializedLambda) + dispatcher.accessible = true + + def assertInvalid = { SerializedLambda serialized -> + def err + try { + dispatcher.invoke(null, serialized) + assert false: 'dispatcher invocation should fail' + } catch (java.lang.reflect.InvocationTargetException e) { + err = e + } + assert err.cause instanceof IllegalArgumentException + assert err.cause.message == 'Invalid serialized functional interface' + } + + assertInvalid(C.withCapturingClass(C.serialized(C.createMethodReference()), String)) + assertInvalid(C.withCapturingClass(C.serialized(C.createLambda()), Integer)) + ''' + } + + @Test + void testDeserializeDispatcherReportsClearErrorForMismatchedSerializedForm() { + assertScript shell, ''' + import java.lang.invoke.MethodHandleInfo + import java.lang.invoke.SerializedLambda + + @CompileStatic + class C { + interface SerFunc<I, O> extends Serializable, Function<I, O> {} + + static SerFunc<Integer, String> create() { + Integer::toString + } + } + + def dispatcher = C.getDeclaredMethod('$deserializeLambda$', SerializedLambda) + dispatcher.accessible = true + + def serialized = new SerializedLambda( + C, + 'java/util/function/Function', + 'apply', + '(Ljava/lang/Object;)Ljava/lang/Object;', + MethodHandleInfo.REF_invokeStatic, + 'java/lang/Integer', + 'toString', + '(I)Ljava/lang/String;', + '(Ljava/lang/Integer;)Ljava/lang/String;', + [] as Object[] + ) + + def err + try { + dispatcher.invoke(null, serialized) + assert false: 'dispatcher invocation should fail' + } catch (java.lang.reflect.InvocationTargetException e) { + err = e + } + assert err.cause instanceof IllegalArgumentException + assert err.cause.message == 'Invalid serialized functional interface' + ''' + } + // GROOVY-11467 @Test void testSuperInterfaceMethodReference() {
