This is an automated email from the ASF dual-hosted git repository.

paulk pushed a commit to branch groovy11884
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/groovy11884 by this push:
     new e7ae47cfe3 GROOVY-11885: groovy-contacts could support scripts
e7ae47cfe3 is described below

commit e7ae47cfe30581812346568af166341e60676f68
Author: Paul King <[email protected]>
AuthorDate: Sat Mar 28 11:29:29 2026 +1000

    GROOVY-11885: groovy-contacts could support scripts
---
 .../src/main/java/groovy/contracts/Invariant.java  |   2 +-
 ...osureExpressionEvaluationASTTransformation.java |  32 ++++++
 .../contracts/generation/CandidateChecks.java      |  17 ++-
 .../contracts/tests/inv/LoopInvariantTests.groovy  | 117 +++++++++++++++++++++
 4 files changed, 166 insertions(+), 2 deletions(-)

diff --git a/subprojects/groovy-contracts/src/main/java/groovy/contracts/Invariant.java b/subprojects/groovy-contracts/src/main/java/groovy/contracts/Invariant.java
index 55e8d905aa..bdc0f2fff1 100644
--- a/subprojects/groovy-contracts/src/main/java/groovy/contracts/Invariant.java
+++ b/subprojects/groovy-contracts/src/main/java/groovy/contracts/Invariant.java
@@ -66,7 +66,7 @@ import java.lang.annotation.Target;
 @Incubating
 @ClassInvariant
 @Repeatable(Invariants.class)
-@ExtendedTarget(ExtendedElementType.LOOP)
+@ExtendedTarget({ExtendedElementType.LOOP, ExtendedElementType.IMPORT})
 @AnnotationProcessorImplementation(ClassInvariantAnnotationProcessor.class)
 @GroovyASTTransformationClass("org.apache.groovy.contracts.ast.LoopInvariantASTTransformation")
 public @interface Invariant {
diff --git a/subprojects/groovy-contracts/src/main/java/org/apache/groovy/contracts/ast/ClosureExpressionEvaluationASTTransformation.java b/subprojects/groovy-contracts/src/main/java/org/apache/groovy/contracts/ast/ClosureExpressionEvaluationASTTransformation.java
index f63b679df1..720de25115 100644
--- a/subprojects/groovy-contracts/src/main/java/org/apache/groovy/contracts/ast/ClosureExpressionEvaluationASTTransformation.java
+++ b/subprojects/groovy-contracts/src/main/java/org/apache/groovy/contracts/ast/ClosureExpressionEvaluationASTTransformation.java
@@ -19,6 +19,7 @@
 package org.apache.groovy.contracts.ast;
 
 import groovy.contracts.Contracted;
+import groovy.contracts.Invariant;
 import org.apache.groovy.contracts.ast.visitor.AnnotationClosureVisitor;
 import org.apache.groovy.contracts.ast.visitor.ConfigurationSetup;
 import org.apache.groovy.contracts.ast.visitor.ContractElementVisitor;
@@ -27,6 +28,7 @@ import org.codehaus.groovy.ast.ASTNode;
 import org.codehaus.groovy.ast.AnnotationNode;
 import org.codehaus.groovy.ast.ClassHelper;
 import org.codehaus.groovy.ast.ClassNode;
+import org.codehaus.groovy.ast.ImportNode;
 import org.codehaus.groovy.ast.ModuleNode;
 import org.codehaus.groovy.control.CompilePhase;
 import org.codehaus.groovy.control.SourceUnit;
@@ -50,12 +52,42 @@ public class ClosureExpressionEvaluationASTTransformation extends BaseASTTransfo
     public void visit(ASTNode[] nodes, SourceUnit unit) {
         final ModuleNode moduleNode = unit.getAST();
 
+        promoteImportInvariantsToScriptClass(moduleNode);
+
         ReaderSource source = getReaderSource(unit);
         final List<ClassNode> classNodes = new ArrayList<>(moduleNode.getClasses());
 
         generateAnnotationClosureClasses(unit, source, classNodes);
     }
 
+    /**
+     * Promotes {@link Invariant} annotations found on import statements to the script body class,
+     * allowing {@code @Invariant} on an import to act as a class invariant for the script.
+     */
+    private static void promoteImportInvariantsToScriptClass(final ModuleNode moduleNode) {
+        final String invariantName = Invariant.class.getName();
+        final List<ImportNode> allImports = new ArrayList<>();
+        allImports.addAll(moduleNode.getImports());
+        allImports.addAll(moduleNode.getStarImports());
+        allImports.addAll(moduleNode.getStaticImports().values());
+        allImports.addAll(moduleNode.getStaticStarImports().values());
+
+        ClassNode scriptClass = null;
+        for (ImportNode importNode : allImports) {
+            for (AnnotationNode annotation : importNode.getAnnotations()) {
+                if (invariantName.equals(annotation.getClassNode().getName())) {
+                    if (scriptClass == null) {
+                        scriptClass = moduleNode.getClasses().stream()
+                                .filter(ClassNode::isScriptBody)
+                                .findFirst().orElse(null);
+                        if (scriptClass == null) return;
+                    }
+                    scriptClass.addAnnotation(annotation);
+                }
+            }
+        }
+    }
+
     private void generateAnnotationClosureClasses(SourceUnit unit, ReaderSource source, List<ClassNode> classNodes) {
         final AnnotationClosureVisitor annotationClosureVisitor = new AnnotationClosureVisitor(unit, source);
 
diff --git a/subprojects/groovy-contracts/src/main/java/org/apache/groovy/contracts/generation/CandidateChecks.java b/subprojects/groovy-contracts/src/main/java/org/apache/groovy/contracts/generation/CandidateChecks.java
index 7c2b541626..50ae38fca6 100644
--- a/subprojects/groovy-contracts/src/main/java/org/apache/groovy/contracts/generation/CandidateChecks.java
+++ b/subprojects/groovy-contracts/src/main/java/org/apache/groovy/contracts/generation/CandidateChecks.java
@@ -18,6 +18,9 @@
  */
 package org.apache.groovy.contracts.generation;
 
+import groovy.contracts.Invariant;
+import org.codehaus.groovy.ast.AnnotationNode;
+import org.codehaus.groovy.ast.ClassHelper;
 import org.codehaus.groovy.ast.ClassNode;
 import org.codehaus.groovy.ast.MethodNode;
 import org.codehaus.groovy.ast.PropertyNode;
@@ -41,7 +44,19 @@ public class CandidateChecks {
      * @return whether the given <tt>type</tt> is a candidate for applying contract assertions
      */
     public static boolean isContractsCandidate(final ClassNode type) {
-        return type != null && !type.isSynthetic() && !type.isInterface() && !type.isEnum() && !type.isGenericsPlaceHolder() && !type.isScript() && !type.isScriptBody() && !isRuntimeClass(type);
+        if (type == null || type.isSynthetic() || type.isInterface() || type.isEnum() || type.isGenericsPlaceHolder() || isRuntimeClass(type)) return false;
+        if ((type.isScript() || type.isScriptBody()) && !hasContractAnnotations(type)) return false;
+        return true;
+    }
+
+    private static boolean hasContractAnnotations(final ClassNode type) {
+        if (!type.getAnnotations(ClassHelper.makeWithoutCaching(Invariant.class)).isEmpty()) return true;
+        for (MethodNode method : type.getMethods()) {
+            for (AnnotationNode annotation : method.getAnnotations()) {
+                if (annotation.getClassNode().getName().startsWith("groovy.contracts.")) return true;
+            }
+        }
+        return false;
     }
 
     /**
diff --git a/subprojects/groovy-contracts/src/test/groovy/org/apache/groovy/contracts/tests/inv/LoopInvariantTests.groovy b/subprojects/groovy-contracts/src/test/groovy/org/apache/groovy/contracts/tests/inv/LoopInvariantTests.groovy
index 3ac1f77565..9e13b9d60f 100644
--- a/subprojects/groovy-contracts/src/test/groovy/org/apache/groovy/contracts/tests/inv/LoopInvariantTests.groovy
+++ b/subprojects/groovy-contracts/src/test/groovy/org/apache/groovy/contracts/tests/inv/LoopInvariantTests.groovy
@@ -155,5 +155,122 @@ class LoopInvariantTests extends BaseTestClass {
             assert f.property == 'hello'
         '''
     }
+
+    @Test
+    void invariantOnImportActsAsScriptClassInvariant() {
+        assertScript '''
+            @groovy.contracts.Invariant({ property != null })
+            import groovy.transform.Field
+
+            @Field String property = 'hello'
+            assert property == 'hello'
+        '''
+    }
+
+    @Test
+    void invariantOnImportViolationThrows() {
+        shouldFail AssertionError, '''
+            @groovy.contracts.Invariant({ property != null })
+            import groovy.transform.Field
+
+            @Field String property = 'hello'
+
+            def nullify() {
+                property = null
+            }
+
+            nullify()
+        '''
+    }
+
+    @Test
+    void invariantOnImportBankAccountScript() {
+        assertScript '''
+            @Invariant({ balance >= 0 })
+            import groovy.transform.Field
+            import groovy.contracts.Invariant
+            import static groovy.test.GroovyAssert.shouldFail
+            import org.apache.groovy.contracts.ClassInvariantViolation
+
+            @Field Integer balance = 5
+
+            def withdraw(int amount) { balance -= amount }
+
+            def deposit(int amount) { balance += amount }
+
+            deposit(5)
+            assert balance == 10
+
+            shouldFail(ClassInvariantViolation) {
+                withdraw(15)
+            }
+
+            balance = 10  // restore valid state (withdraw left balance at -5)
+
+            shouldFail(ClassInvariantViolation) {
+                deposit(-15)
+            }
+
+            balance = 10  // restore valid state before run() ends
+        '''
+    }
+
+    @Test
+    void ensuresOnScriptMethods() {
+        assertScript '''
+            import groovy.transform.Field
+            import groovy.contracts.*
+            import static groovy.test.GroovyAssert.shouldFail
+
+            @Field Integer balance = 5
+
+            @Ensures({ balance >= 0 })
+            def withdraw(int amount) { balance -= amount }
+
+            @Ensures({ balance >= 0 })
+            def deposit(int amount) { balance += amount }
+
+            deposit(5)
+            assert balance == 10
+
+            shouldFail(AssertionError) {
+                withdraw(15)
+            }
+
+            balance = 5
+
+            shouldFail(AssertionError) {
+                deposit(-10)
+            }
+        '''
+    }
+
+    @Test
+    void requiresOnScriptMethods() {
+        assertScript '''
+            import groovy.transform.Field
+            import groovy.contracts.*
+            import static groovy.test.GroovyAssert.shouldFail
+
+            @Field Integer balance = 5
+
+            @Requires({ balance >= amount })
+            def withdraw(int amount) { balance -= amount }
+
+            @Requires({ amount >= 0 })
+            def deposit(int amount) { balance += amount }
+
+            deposit(5)
+            assert balance == 10
+
+            shouldFail(AssertionError) {
+                withdraw(15)
+            }
+
+            shouldFail(AssertionError) {
+                deposit(-15)
+            }
+        '''
+    }
 }
 

Reply via email to