This is an automated email from the ASF dual-hosted git repository. paulk-asert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/groovy.git
commit f7a1256bbff48405ab5572fbfde18d8922d7fb0a Author: Paul King <[email protected]> AuthorDate: Wed May 6 17:51:33 2026 +1000 GROOVY-11998: Better support of intersection types (part 3) Bytecode for lambdas/method refs. writeFunctionalInterfaceIndy markers, altMetafactory flags. Static parity with Java's (R & S) ()->…. --- .../groovy/classgen/AsmClassGenerator.java | 16 +- .../asm/sc/AbstractFunctionalInterfaceWriter.java | 89 +++++++- .../classgen/asm/sc/StaticTypesLambdaWriter.java | 27 ++- ...StaticTypesMethodReferenceExpressionWriter.java | 28 ++- .../groovy/lang/IntersectionCastE2ETest.groovy | 183 ++++++++++++++++ .../transform/stc/IntersectionCastSTCTest.groovy | 243 +++++++++++++++++++++ 6 files changed, 567 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/codehaus/groovy/classgen/AsmClassGenerator.java b/src/main/java/org/codehaus/groovy/classgen/AsmClassGenerator.java index 1e3af7d388..c98b943da6 100644 --- a/src/main/java/org/codehaus/groovy/classgen/AsmClassGenerator.java +++ b/src/main/java/org/codehaus/groovy/classgen/AsmClassGenerator.java @@ -56,6 +56,7 @@ import org.codehaus.groovy.ast.expr.EmptyExpression; import org.codehaus.groovy.ast.expr.Expression; import org.codehaus.groovy.ast.expr.FieldExpression; import org.codehaus.groovy.ast.expr.GStringExpression; +import org.codehaus.groovy.ast.IntersectionTypeClassNode; import org.codehaus.groovy.ast.expr.LambdaExpression; import org.codehaus.groovy.ast.expr.ListExpression; import org.codehaus.groovy.ast.expr.MapEntryExpression; @@ -1000,9 +1001,22 @@ public class AsmClassGenerator extends ClassGenerator { @Override public void visitCastExpression(final CastExpression castExpression) { Expression expression = castExpression.getExpression(); + ClassNode type = castExpression.getType(); + + // GROOVY-11998: lambda / method-reference factory invocations already + // emit an object that implements every component of an intersection + // target via altMetafactory FLAG_MARKERS, so the outer 
cast is a no-op + // here. Non-functional intersection casts are out of scope for PR3 + // and fall through to the default handling below. + if (type instanceof IntersectionTypeClassNode + && (expression instanceof LambdaExpression + || expression instanceof MethodReferenceExpression)) { + expression.visit(this); + return; + } + expression.visit(this); - ClassNode type = castExpression.getType(); if (isObjectType(type)) return; maybeInnerClassEntry(type); diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionalInterfaceWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionalInterfaceWriter.java index 7d437bf566..76eac91cb9 100644 --- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionalInterfaceWriter.java +++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionalInterfaceWriter.java @@ -78,11 +78,32 @@ public interface AbstractFunctionalInterfaceWriter { final String samMethodDescriptor, final int implMethodKind, final ClassNode implClassNode, final MethodNode implMethodNode, final Parameter[] implMethodParameters, final boolean serializable) { + // GROOVY-11998: delegate to marker-aware overload with no extra interfaces + writeFunctionalInterfaceIndy(methodVisitor, samMethodName, invokedTypeDescriptor, + samMethodDescriptor, implMethodKind, implClassNode, implMethodNode, + implMethodParameters, serializable, ClassNode.EMPTY_ARRAY); + } + + /** + * Marker-aware variant for intersection-cast targets such as + * {@code (Runnable & Cloneable) () -> ...}. Markers are threaded through + * {@code LambdaMetafactory.altMetafactory} via {@code FLAG_MARKERS} so the + * generated lambda implements every component interface at runtime. 
+ * + * @since 5.0.0 + */ + default void writeFunctionalInterfaceIndy(final MethodVisitor methodVisitor, + final String samMethodName, final String invokedTypeDescriptor, + final String samMethodDescriptor, final int implMethodKind, + final ClassNode implClassNode, final MethodNode implMethodNode, + final Parameter[] implMethodParameters, final boolean serializable, + final ClassNode[] markers) { + boolean useAlt = serializable || (markers != null && markers.length > 0); methodVisitor.visitInvokeDynamicInsn( samMethodName, invokedTypeDescriptor, - createBootstrapMethod(serializable), - createBootstrapMethodArguments(samMethodDescriptor, implMethodKind, implClassNode, implMethodNode, implMethodParameters, serializable) + createBootstrapMethod(useAlt), + createBootstrapMethodArguments(samMethodDescriptor, implMethodKind, implClassNode, implMethodNode, implMethodParameters, serializable, markers) ); if (serializable) { methodVisitor.visitTypeInsn(CHECKCAST, BytecodeHelper.getClassInternalName(ClassHelper.SERIALIZABLE_TYPE)); @@ -188,12 +209,12 @@ public interface AbstractFunctionalInterfaceWriter { return new Parameter[] { new Parameter(SERIALIZEDLAMBDA_TYPE, "serializedLambda") }; } - private Handle createBootstrapMethod(final boolean serializable) { + private Handle createBootstrapMethod(final boolean useAlt) { return new Handle( Opcodes.H_INVOKESTATIC, "java/lang/invoke/LambdaMetafactory", - serializable ? "altMetafactory" : "metafactory", - serializable ? "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;" + useAlt ? "altMetafactory" : "metafactory", + useAlt ? 
"(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;" : "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;", false // GROOVY-8299, GROOVY-8989, GROOVY-11265 ); @@ -202,20 +223,64 @@ public interface AbstractFunctionalInterfaceWriter { private Object[] createBootstrapMethodArguments(final String samMethodDescriptor, final int implMethodKind, final ClassNode implClassNode, final MethodNode implMethodNode, final Parameter[] implMethodParameters, final boolean serializable) { - Object[] arguments = !serializable ? new Object[3] : new Object[]{null, null, null, 5, 0}; + return createBootstrapMethodArguments(samMethodDescriptor, implMethodKind, implClassNode, implMethodNode, + implMethodParameters, serializable, ClassNode.EMPTY_ARRAY); + } - arguments[0] = Type.getMethodType(samMethodDescriptor); + /** + * GROOVY-11998: builds the variadic args for {@code LambdaMetafactory.altMetafactory} + * including {@code FLAG_MARKERS} when the cast target is an intersection. + * Layout per the JDK contract: + * <pre> + * samMethodType, implMethod, instantiatedMethodType, + * flags, + * [markerCount, marker_1, ..., marker_n] // when FLAG_MARKERS set + * [bridgeCount, bridge_1, ..., bridge_n] // when FLAG_BRIDGES set (unused) + * </pre> + */ + private Object[] createBootstrapMethodArguments(final String samMethodDescriptor, final int implMethodKind, + final ClassNode implClassNode, final MethodNode implMethodNode, + final Parameter[] implMethodParameters, final boolean serializable, + final ClassNode[] markers) { + ClassNode[] effectiveMarkers = markers == null ? 
ClassNode.EMPTY_ARRAY : markers; + boolean hasMarkers = effectiveMarkers.length > 0; + boolean useAlt = serializable || hasMarkers; + + if (!useAlt) { + Object[] args = new Object[3]; + args[0] = Type.getMethodType(samMethodDescriptor); + args[1] = new Handle( + implMethodKind, + getClassInternalName(implClassNode.getName()), + implMethodNode.getName(), + getMethodDescriptor(implMethodNode), + implClassNode.isInterface()); + args[2] = createInstantiatedMethodType(samMethodDescriptor, implMethodNode, implMethodParameters); + return args; + } - arguments[1] = new Handle( + int flags = (serializable ? 1 : 0) | (hasMarkers ? 2 : 0); // FLAG_SERIALIZABLE | FLAG_MARKERS + int size = 4 + (hasMarkers ? 1 + effectiveMarkers.length : 0); + Object[] args = new Object[size]; + + args[0] = Type.getMethodType(samMethodDescriptor); + args[1] = new Handle( implMethodKind, // H_INVOKESTATIC or H_INVOKEVIRTUAL or H_INVOKEINTERFACE (GROOVY-9853) getClassInternalName(implClassNode.getName()), implMethodNode.getName(), getMethodDescriptor(implMethodNode), implClassNode.isInterface()); - - arguments[2] = createInstantiatedMethodType(samMethodDescriptor, implMethodNode, implMethodParameters); - - return arguments; + args[2] = createInstantiatedMethodType(samMethodDescriptor, implMethodNode, implMethodParameters); + args[3] = flags; + + if (hasMarkers) { + int p = 4; + args[p++] = effectiveMarkers.length; + for (ClassNode m : effectiveMarkers) { + args[p++] = Type.getObjectType(getClassInternalName(m)); + } + } + return args; } private Type createInstantiatedMethodType(final String samMethodDescriptor, final MethodNode implMethodNode, final Parameter[] implMethodParameters) { diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java index 3b393b41c9..22899971f0 100644 --- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java +++ 
b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java @@ -56,6 +56,7 @@ import static org.codehaus.groovy.classgen.asm.sc.StaticTypesFunctionalInterface import static org.codehaus.groovy.classgen.asm.sc.StaticTypesFunctionalInterfaceMetadataKey.LAMBDA_PRELOADED_RECEIVER; import static org.codehaus.groovy.classgen.asm.sc.StaticTypesFunctionalInterfaceMetadataKey.LAMBDA_SHARED_VARIABLES; import static org.codehaus.groovy.transform.stc.StaticTypesMarker.CLOSURE_ARGUMENTS; +import static org.codehaus.groovy.transform.stc.StaticTypesMarker.LAMBDA_MARKERS; import static org.codehaus.groovy.transform.stc.StaticTypesMarker.PARAMETER_TYPE; import static org.objectweb.asm.Opcodes.ACC_FINAL; import static org.objectweb.asm.Opcodes.ACC_PRIVATE; @@ -94,6 +95,10 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun } boolean serializable = makeSerializableIfNeeded(expression, functionalType); + // GROOVY-11998: gather intersection-cast marker interfaces, filtering out + // Serializable (already conveyed via FLAG_SERIALIZABLE) and any interface + // already implemented by the SAM target. 
+ ClassNode[] markers = collectLambdaMarkers(expression, functionalType); GeneratedLambda generatedLambda = getOrAddGeneratedLambda(expression, abstractMethod); ensureDeserializeLambdaSupport(expression, functionalType, abstractMethod, generatedLambda, serializable); @@ -101,7 +106,22 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun loadLambdaReceiver(generatedLambda); } - writeLambdaFactoryInvocation(functionalType.redirect(), abstractMethod, generatedLambda, serializable); + writeLambdaFactoryInvocation(functionalType.redirect(), abstractMethod, generatedLambda, serializable, markers); + } + + @SuppressWarnings("unchecked") + private static ClassNode[] collectLambdaMarkers(final LambdaExpression expression, final ClassNode functionalType) { + Object md = expression.getNodeMetaData(LAMBDA_MARKERS); + if (!(md instanceof java.util.List)) return ClassNode.EMPTY_ARRAY; + java.util.List<ClassNode> raw = (java.util.List<ClassNode>) md; + java.util.List<ClassNode> out = new java.util.ArrayList<>(raw.size()); + for (ClassNode m : raw) { + if (m == null || !m.isInterface()) continue; + if (m.equals(SERIALIZABLE_TYPE) || SERIALIZABLE_TYPE.equals(m.redirect())) continue; + if (functionalType != null && functionalType.implementsInterface(m)) continue; + out.add(m); + } + return out.toArray(ClassNode.EMPTY_ARRAY); } private static MethodNode resolveFunctionalInterfaceMethod(final ClassNode functionalType) { @@ -138,7 +158,7 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun ), helperMethod); } - private void writeLambdaFactoryInvocation(final ClassNode functionalType, final MethodNode abstractMethod, final GeneratedLambda generatedLambda, final boolean serializable) { + private void writeLambdaFactoryInvocation(final ClassNode functionalType, final MethodNode abstractMethod, final GeneratedLambda generatedLambda, final boolean serializable, final ClassNode[] markers) { writeFunctionalInterfaceIndy( 
controller.getMethodVisitor(), abstractMethod.getName(), @@ -148,7 +168,8 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun generatedLambda.lambdaClass, generatedLambda.lambdaMethod, generatedLambda.lambdaMethod.getParameters(), - serializable + serializable, + markers ); if (generatedLambda.nonCapturing()) { diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java index 0ee559c9ad..d6f0ad21aa 100644 --- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java +++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java @@ -122,7 +122,8 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE invocationReadyMethodReference.implementationMethod().getDeclaringClass(), invocationReadyMethodReference.implementationMethod(), functionalInterface.parametersWithExactType(), - functionalInterface.serializable() + functionalInterface.serializable(), + functionalInterface.markers() ); updateOperandStack(functionalInterface.functionalType(), invocation.capturing()); @@ -137,15 +138,36 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE if (abstractMethod == null) return null; ClassNode[] inferredParameterTypes = methodReferenceExpression.getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS); + // GROOVY-11998: pick up intersection-cast markers populated by STC + @SuppressWarnings("unchecked") + java.util.List<ClassNode> rawMarkers = (java.util.List<ClassNode>) methodReferenceExpression.getNodeMetaData(StaticTypesMarker.LAMBDA_MARKERS); + boolean fromIntersection = rawMarkers != null && rawMarkers.stream().anyMatch(m -> + m != null && (m.equals(ClassHelper.SERIALIZABLE_TYPE) || m.implementsInterface(ClassHelper.SERIALIZABLE_TYPE))); + boolean 
serializable = functionalType.implementsInterface(ClassHelper.SERIALIZABLE_TYPE) || fromIntersection; + ClassNode[] markers = filterMarkers(rawMarkers, functionalType); return new FunctionalInterfaceContext( functionalType, abstractMethod, createParametersWithExactType(abstractMethod, inferredParameterTypes), createMethodDescriptor(abstractMethod), - functionalType.implementsInterface(ClassHelper.SERIALIZABLE_TYPE) + serializable, + markers ); } + private static ClassNode[] filterMarkers(final java.util.List<ClassNode> raw, final ClassNode functionalType) { + if (raw == null || raw.isEmpty()) return ClassNode.EMPTY_ARRAY; + java.util.List<ClassNode> out = new java.util.ArrayList<>(raw.size()); + for (ClassNode m : raw) { + if (m == null || !m.isInterface()) continue; + if (m.equals(ClassHelper.SERIALIZABLE_TYPE) + || ClassHelper.SERIALIZABLE_TYPE.equals(m.redirect())) continue; + if (functionalType != null && functionalType.implementsInterface(m)) continue; + out.add(m); + } + return out.toArray(ClassNode.EMPTY_ARRAY); + } + private MethodReferenceTarget resolveMethodReferenceTarget(final MethodReferenceExpression methodReferenceExpression) { Expression typeOrTargetRef = methodReferenceExpression.getExpression(); boolean classExpression = typeOrTargetRef instanceof ClassExpression; @@ -684,7 +706,7 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE */ private record FunctionalInterfaceContext(ClassNode functionalType, MethodNode abstractMethod, Parameter[] parametersWithExactType, String samMethodDescriptor, - boolean serializable) { + boolean serializable, ClassNode[] markers) { } /** diff --git a/src/test/groovy/groovy/lang/IntersectionCastE2ETest.groovy b/src/test/groovy/groovy/lang/IntersectionCastE2ETest.groovy new file mode 100644 index 0000000000..5887ad76e0 --- /dev/null +++ b/src/test/groovy/groovy/lang/IntersectionCastE2ETest.groovy @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or 
more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package groovy.lang + +import org.junit.jupiter.api.Test + +/** + * End-to-end tests for intersection-cast lambdas and method references + * (GROOVY-11998 PR3). + * + * Verifies that {@code (R & S) lambda} and {@code (R & S) Type::method} produce + * runtime instances that: + * <ul> + * <li>Implement every component interface ({@code instanceof} succeeds)</li> + * <li>Are serialisable when the intersection contains + * {@link java.io.Serializable}</li> + * <li>Invoke the SAM correctly</li> + * </ul> + */ +final class IntersectionCastE2ETest { + + @Test + void 'static lambda cast to (Runnable & Serializable) implements both and runs'() { + def shell = new GroovyShell() + shell.evaluate(''' + import groovy.transform.CompileStatic + import java.io.Serializable + + @CompileStatic + class T { + static Runnable make() { + return (Runnable & Serializable) () -> {} + } + } + + def r = T.make() + assert r instanceof Runnable + assert r instanceof Serializable + r.run() // does not throw + ''') + } + + @Test + void 'static intersection lambda is serialisable round-trip'() { + def shell = new GroovyShell() + shell.evaluate(''' + import groovy.transform.CompileStatic + import java.io.Serializable + import java.io.ByteArrayOutputStream + import 
java.io.ByteArrayInputStream + import java.io.ObjectOutputStream + import java.io.ObjectInputStream + + @CompileStatic + class T { + static Runnable make() { + return (Runnable & Serializable) () -> { System.out.println("hi") } + } + } + + def r = T.make() + def baos = new ByteArrayOutputStream() + new ObjectOutputStream(baos).withCloseable { it.writeObject(r) } + def bais = new ByteArrayInputStream(baos.toByteArray()) + def restored = null + new ObjectInputStream(bais).withCloseable { restored = it.readObject() } + assert restored instanceof Runnable + assert restored instanceof Serializable + restored.run() // round-trip executes + ''') + } + + @Test + void 'static lambda cast to (Runnable & Cloneable) implements Cloneable marker'() { + def shell = new GroovyShell() + shell.evaluate(''' + import groovy.transform.CompileStatic + + @CompileStatic + class T { + static Runnable make() { + return (Runnable & Cloneable) () -> {} + } + } + + def r = T.make() + assert r instanceof Runnable + assert r instanceof Cloneable + ''') + } + + @Test + void 'static method reference cast to (Function & Serializable) is serialisable'() { + def shell = new GroovyShell() + shell.evaluate(''' + import groovy.transform.CompileStatic + import java.io.Serializable + import java.util.function.Function + import java.io.ByteArrayOutputStream + import java.io.ByteArrayInputStream + import java.io.ObjectOutputStream + import java.io.ObjectInputStream + + @CompileStatic + class T { + static Function<String, Integer> make() { + return (Function<String, Integer> & Serializable) String::length + } + } + + Function<String, Integer> f = T.make() + assert f instanceof Function + assert f instanceof Serializable + assert f.apply("hello") == 5 + + def baos = new ByteArrayOutputStream() + new ObjectOutputStream(baos).withCloseable { it.writeObject(f) } + def bais = new ByteArrayInputStream(baos.toByteArray()) + Function<String, Integer> restored = null + new ObjectInputStream(bais).withCloseable { 
restored = (Function<String, Integer>) it.readObject() } + assert restored.apply("world") == 5 + ''') + } + + @Test + void 'static lambda with capturing variable cast to intersection works'() { + def shell = new GroovyShell() + shell.evaluate(''' + import groovy.transform.CompileStatic + import java.io.Serializable + import java.util.function.Supplier + + @CompileStatic + class T { + static Supplier<String> make(String captured) { + return (Supplier<String> & Serializable) () -> captured + } + } + + def s = T.make("captured-value") + assert s instanceof Supplier + assert s instanceof Serializable + assert s.get() == "captured-value" + ''') + } + + @Test + void 'intersection lambda with three components includes all markers'() { + def shell = new GroovyShell() + shell.evaluate(''' + import groovy.transform.CompileStatic + import java.io.Serializable + + @CompileStatic + class T { + static Runnable make() { + return (Runnable & Serializable & Cloneable) () -> {} + } + } + + def r = T.make() + assert r instanceof Runnable + assert r instanceof Serializable + assert r instanceof Cloneable + ''') + } +} diff --git a/src/test/groovy/groovy/transform/stc/IntersectionCastSTCTest.groovy b/src/test/groovy/groovy/transform/stc/IntersectionCastSTCTest.groovy new file mode 100644 index 0000000000..ba9186c10c --- /dev/null +++ b/src/test/groovy/groovy/transform/stc/IntersectionCastSTCTest.groovy @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package groovy.transform.stc + +import groovy.transform.TypeChecked +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.IntersectionTypeClassNode +import org.codehaus.groovy.ast.expr.CastExpression +import org.codehaus.groovy.ast.expr.LambdaExpression +import org.codehaus.groovy.ast.tools.GenericsUtils +import org.codehaus.groovy.control.CompilationUnit +import org.codehaus.groovy.control.CompilerConfiguration +import org.codehaus.groovy.control.MultipleCompilationErrorsException +import org.codehaus.groovy.control.Phases +import org.codehaus.groovy.control.SourceUnit +import org.codehaus.groovy.control.customizers.ASTTransformationCustomizer +import org.codehaus.groovy.control.customizers.ImportCustomizer +import org.codehaus.groovy.control.messages.SyntaxErrorMessage +import org.codehaus.groovy.transform.stc.StaticTypesMarker +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test + +/** + * Tests for static type checking of intersection-cast targets (GROOVY-11998 PR2). + * + * These tests stop compilation at {@link Phases#INSTRUCTION_SELECTION} so they + * exercise resolution and STC without relying on bytecode generation, which + * is delivered in subsequent phases.
+ */ +final class IntersectionCastSTCTest { + + @Test + void 'STC accepts (Runnable & Serializable) lambda'() { + def cu = compileToSemantic(''' + @groovy.transform.TypeChecked + class T { + static def make() { + return (Runnable & java.io.Serializable) () -> { System.out.println("hi") } + } + } + ''') + assertNoErrors(cu) + } + + @Test + void 'STC accepts (Runnable & Serializable) closure'() { + def cu = compileToSemantic(''' + @groovy.transform.TypeChecked + class T { + static def make() { + return (Runnable & java.io.Serializable) { -> System.out.println("hi") } + } + } + ''') + assertNoErrors(cu) + } + + @Test + void 'STC sets PRIMARY_FUNCTIONAL_TYPE and LAMBDA_MARKERS on intersection-cast lambda'() { + def cu = compileToSemantic(''' + @groovy.transform.TypeChecked + class T { + static def make() { + return (Runnable & java.io.Serializable) () -> { System.out.println("hi") } + } + } + ''') + assertNoErrors(cu) + CastExpression cast = findFirstIntersectionCast(cu) + assert cast != null + + ClassNode primary = (ClassNode) cast.getNodeMetaData(StaticTypesMarker.PRIMARY_FUNCTIONAL_TYPE) + assert primary != null + assert primary.name == 'java.lang.Runnable' + + List<ClassNode> markers = (List<ClassNode>) cast.getNodeMetaData(StaticTypesMarker.LAMBDA_MARKERS) + assert markers != null + assert markers.size() == 1 + assert markers[0].name == 'java.io.Serializable' + } + + @Test + void 'lambda is marked Serializable when intersection includes Serializable'() { + def cu = compileToSemantic(''' + @groovy.transform.TypeChecked + class T { + static def make() { + return (Runnable & java.io.Serializable) () -> { System.out.println("hi") } + } + } + ''') + assertNoErrors(cu) + CastExpression cast = findFirstIntersectionCast(cu) + assert cast != null + assert cast.expression instanceof LambdaExpression + assert ((LambdaExpression) cast.expression).serializable + } + + @Test + void 'STC rejects intersection with two SAM-bearing interfaces for lambda target'() { + // Define two 
SAM-bearing interfaces inside a script + def errors = compileExpectingErrors(''' + interface A { void runIt() } + interface B { void doIt() } + @groovy.transform.TypeChecked + class T { + static def make() { + return (A & B) () -> {} + } + } + ''') + assert errors.any { it.contains('multiple functional interface components') } + } + + @Test + void 'STC rejects intersection with no functional interface for lambda target'() { + def errors = compileExpectingErrors(''' + @groovy.transform.TypeChecked + class T { + static def make() { + return (java.io.Serializable & Cloneable) () -> {} + } + } + ''') + assert errors.any { it.contains('no functional interface component') } + } + + @Test + void 'STC rejects intersection where the class component is not first'() { + def errors = compileExpectingErrors(''' + class C {} + @groovy.transform.TypeChecked + class T { + static def make(value) { + return (Runnable & C) value + } + } + ''') + assert errors.any { it.contains('Class component of intersection type must come first') } + } + + @Test + void 'STC rejects intersection with a final class component'() { + def errors = compileExpectingErrors(''' + @groovy.transform.TypeChecked + class T { + static def make(value) { + return (String & Runnable) value + } + } + ''') + assert errors.any { it.contains('may not include the final class') } + } + + @Test + void 'resolver resolves all components and reclassifies'() { + def cu = compileToSemantic(''' + @groovy.transform.TypeChecked + class T { + static def make() { + return (Runnable & java.io.Serializable) () -> { } + } + } + ''') + assertNoErrors(cu) + CastExpression cast = findFirstIntersectionCast(cu) + assert cast != null + IntersectionTypeClassNode it = (IntersectionTypeClassNode) cast.type + ClassNode[] components = it.components + assert components.length == 2 + assert components.every { it.isResolved() || !it.isPrimaryClassNode() } + // After resolution + reclassification, both components are interfaces, so superClass is 
Object + assert it.superClass.name == 'java.lang.Object' + assert it.interfaces*.name as Set == ['java.lang.Runnable', 'java.io.Serializable'] as Set + } + + //-------------------------------------------------------------------------- + + private static CompilationUnit compileToSemantic(String src) { + CompilerConfiguration config = new CompilerConfiguration() + ImportCustomizer imports = new ImportCustomizer() + config.addCompilationCustomizers(imports) + + CompilationUnit cu = new CompilationUnit(config, null, new GroovyClassLoader()) + cu.addSource('Test.groovy', src) + try { + cu.compile(Phases.INSTRUCTION_SELECTION) + } catch (MultipleCompilationErrorsException ignored) { + // tests inspect cu.errorCollector + } + return cu + } + + private static List<String> compileExpectingErrors(String src) { + CompilationUnit cu = compileToSemantic(src) + return cu.errorCollector.errors.findAll { it instanceof SyntaxErrorMessage } + .collect { ((SyntaxErrorMessage) it).cause.message } + } + + private static void assertNoErrors(CompilationUnit cu) { + if (cu.errorCollector.hasErrors()) { + String msg = cu.errorCollector.errors + .findAll { it instanceof SyntaxErrorMessage } + .collect { ((SyntaxErrorMessage) it).cause.message }.join('\n') + Assertions.fail("Compilation produced errors:\n${msg}") + } + } + + private static CastExpression findFirstIntersectionCast(CompilationUnit cu) { + CastExpression[] holder = new CastExpression[1] + cu.AST.classes.each { cn -> + cn.methods.each { mn -> + if (mn.code == null) return + mn.code.visit(new org.codehaus.groovy.ast.CodeVisitorSupport() { + @Override + void visitCastExpression(CastExpression expression) { + if (holder[0] == null && expression.type instanceof IntersectionTypeClassNode) { + holder[0] = expression + } + super.visitCastExpression(expression) + } + }) + } + } + return holder[0] + } +}
