http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/tools/GrapeMain.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/tools/GrapeMain.groovy b/src/main/groovy/org/codehaus/groovy/tools/GrapeMain.groovy new file mode 100644 index 0000000..2f53afa --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/tools/GrapeMain.groovy @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.codehaus.groovy.tools + +import groovy.grape.Grape +import groovy.transform.Field +import org.apache.commons.cli.CommandLine +import org.apache.commons.cli.DefaultParser +import org.apache.commons.cli.HelpFormatter +import org.apache.commons.cli.Option +import org.apache.commons.cli.OptionGroup +import org.apache.commons.cli.Options +import org.apache.ivy.util.DefaultMessageLogger +import org.apache.ivy.util.Message + +//commands +// Each command closure is called with the positional arguments (arg[0] is the command name itself) +// and the CommandLine holding the parsed global options. + +// install <group> <module> [<version> [<classifier>]]: grabs the grape with autoDownload enabled. +// The version defaults to '*'; additional resolvers can be supplied with -r <url>. +@Field install = {arg, cmd -> + if (arg.size() > 5 || arg.size() < 3) { + println 'install requires two to four arguments: <group> <module> [<version> [<classifier>]]' + return + } + def ver = '*' + if (arg.size() >= 4) { + ver = arg[3] + } + def classifier = null + if (arg.size() >= 5) { + classifier = arg[4] + } + + // set the instance so we can re-set the logger + Grape.getInstance() + setupLogging() + + cmd.getOptionValues('r')?.each { String url -> + Grape.addResolver(name:url, root:url) + } + + try { + Grape.grab(autoDownload: true, group: arg[1], module: arg[2], version: ver, classifier: classifier, noExceptions: true) + } catch (Exception e) { + // report the exception bound by this catch clause; referencing an undefined name here would itself throw + println "An error occurred : $e" + } +} + +// uninstall <group> <module> <version>: removes one cached version of a grape; when nothing matches exactly, +// prints the cached groups where the group or a module name partially matches (substring in either direction). +@Field uninstall = {arg, cmd -> + if (arg.size() != 4) { + println 'uninstall requires three arguments: <group> <module> <version>' + // TODO make version optional? support classifier? 
+// println 'uninstall requires two to four arguments, <group> <module> [<version>] [<classifier>]' + return + } + String group = arg[1] + String module = arg[2] + String ver = arg[3] +// def classifier = null + + // set the instance so we can re-set the logger + Grape.getInstance() + setupLogging() + + if (!Grape.enumerateGrapes().find {String groupName, Map g -> + g.any {String moduleName, List<String> versions -> + group == groupName && module == moduleName && ver in versions + } + }) { + println "uninstall did not find grape matching: $group $module $ver" + def fuzzyMatches = Grape.enumerateGrapes().findAll { String groupName, Map g -> + g.any {String moduleName, List<String> versions -> + groupName.contains(group) || moduleName.contains(module) || + group.contains(groupName) || module.contains(moduleName) + } + } + if (fuzzyMatches) { + println 'possible matches:' + fuzzyMatches.each { String groupName, Map g -> println " $groupName: $g" } + } + return + } + Grape.instance.uninstallArtifact(group, module, ver) +} + +// list: prints every cached "<group> <module> [<versions>]" entry, followed by module and version totals. +@Field list = {arg, cmd -> + println "" + + int moduleCount = 0 + int versionCount = 0 + + // set the instance so we can re-set the logger + Grape.getInstance() + setupLogging() + + Grape.enumerateGrapes().each {String groupName, Map group -> + group.each {String moduleName, List<String> versions -> + println "$groupName $moduleName $versions" + moduleCount++ + versionCount += versions.size() + } + } + println "" + println "$moduleCount Grape modules cached" + println "$versionCount Grape module versions cached" +} + +// resolve [-a|-d|-s|-i] (<group> <module> <version>)+: prints the resolved artifacts as a plain list, +// ant <pathelement>s (-a), a DOS (-d) or shell (-s) CLASSPATH assignment, or ivy <dependency> entries (-i). 
+@Field resolve = {arg, cmd -> + Options options = new Options(); + options.addOption(Option.builder("a").hasArg(false).longOpt("ant").build()); + options.addOption(Option.builder("d").hasArg(false).longOpt("dos").build()); + options.addOption(Option.builder("s").hasArg(false).longOpt("shell").build()); + options.addOption(Option.builder("i").hasArg(false).longOpt("ivy").build()); + // arg[0] is the command name; only slice when more arguments exist, since arg[1..-1] on a + // one-element array throws instead of yielding an empty list (e.g. a bare 'grape resolve') + CommandLine cmd2 = new DefaultParser().parse(options, 
(arg.size() > 1 ? arg[1..-1] : []) as String[], true); + arg = cmd2.args + + // set the instance so we can re-set the logger + Grape.getInstance() + setupLogging(Message.MSG_ERR) + + if ((arg.size() % 3) != 0) { + println 'There needs to be a multiple of three arguments: (group module version)+' + return + } + // check the grape references left in 'arg'; the script-level 'args' still holds the global options and command name + if (arg.size() < 3) { + println 'At least one Grape reference is required' + return + } + def before, between, after + def ivyFormatRequested = false + + if (cmd2.hasOption('a')) { + before = '<pathelement location="' + between = '">\n<pathelement location="' + after = '">' + } else if (cmd2.hasOption('d')) { + before = 'set CLASSPATH=' + between = ';' + after = '' + } else if (cmd2.hasOption('s')) { + before = 'export CLASSPATH=' + between = ':' + after = '' + } else if (cmd2.hasOption('i')) { + ivyFormatRequested = true + before = '<dependency ' + between = '">\n<dependency ' + after = '">' + } else { + before = '' + between = '\n' + after = '\n' + } + + // local iterator (a bare assignment would create a script binding variable) + def iter = arg.iterator() + def params = [[:]] + def depsInfo = [] // this list will contain the module/group/version info of all resolved dependencies + if(ivyFormatRequested) { + params << depsInfo + } + while (iter.hasNext()) { + params.add([group: iter.next(), module: iter.next(), version: iter.next()]) + } + try { + def results = [] + def uris = Grape.resolve(* params) + if(!ivyFormatRequested) { + for (URI uri: uris) { + if (uri.scheme == 'file') { + results += new File(uri).path + } else { + results += uri.toASCIIString() + } + } + } else { + depsInfo.each { dep -> + results += ('org="' + dep.group + '" name="' + dep.module + '" revision="' + dep.revision) + } + } + + if (results) { + println "${before}${results.join(between)}${after}" + } else { + println 'Nothing was resolved' + } + } catch (Exception e) { + println "Error in resolve:\n\t$e.message" + if (e.message =~ /unresolved dependency/) println "Perhaps the grape is not installed?" 
+ } +} + +// help: prints the usage summary. +@Field help = { arg, cmd -> grapeHelp() } + +// command name -> [closure: implementation, shortHelp: one-line description listed by grapeHelp] +@Field commands = [ + 'install': [closure: install, + shortHelp: 'Installs a particular grape'], + 'uninstall': [closure: uninstall, + shortHelp: 'Uninstalls a particular grape (non-transitively removes the respective jar file from the grape cache)'], + 'list': [closure: list, + shortHelp: 'Lists all installed grapes'], + 'resolve': [closure: resolve, + shortHelp: 'Enumerates the jars used by a grape'], + 'help': [closure: help, + shortHelp: 'Usage information'] +] + +// Prints the global option table (commons-cli HelpFormatter, written to the script's 'out' binding variable +// when set, otherwise System.out) followed by the available commands and their short help. +@Field grapeHelp = { + int spacesLen = commands.keySet().max {it.length()}.length() + 3 + String spaces = ' ' * spacesLen + + PrintWriter pw = new PrintWriter(binding.variables.out ?: System.out) + new HelpFormatter().printHelp( + pw, + 80, + "grape [options] <command> [args]\n", + "options:", + options, + 2, + 4, + null, // footer + true); + pw.flush() + + println "" + println "commands:" + commands.each {String k, v -> + println " ${(k + spaces).substring(0, spacesLen)} $v.shortHelp" + } + println "" +} + +// Installs an Ivy DefaultMessageLogger at the level picked by the global -q/-w/-i/-V/-d flag, falling back to +// defaultLevel (2, which the inline note equates with Message.MSG_INFO) when none of them was given. 
+@Field setupLogging = {int defaultLevel = 2 -> // = Message.MSG_INFO -> some parsing error :( + if (cmd.hasOption('q')) { + Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_ERR)) + } else if (cmd.hasOption('w')) { + Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_WARN)) + } else if (cmd.hasOption('i')) { + Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_INFO)) + } else if (cmd.hasOption('V')) { + Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_VERBOSE)) + } else if (cmd.hasOption('d')) { + Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_DEBUG)) + } else { + Message.setDefaultLogger(new DefaultMessageLogger(defaultLevel)) + } +} + +// command line parsing +@Field Options options = new Options(); + +options.addOption(Option.builder("D").longOpt("define").desc("define a system property").numberOfArgs(2).valueSeparator().argName("name=value").build()); 
+options.addOption(Option.builder("r").longOpt("resolver").desc("define a grab resolver (for install)").hasArg(true).argName("url").build()); +options.addOption(Option.builder("h").hasArg(false).desc("usage information").longOpt("help").build()); + +// Logging Level Options +options.addOptionGroup( + new OptionGroup() + .addOption(Option.builder("q").hasArg(false).desc("Log level 0 - only errors").longOpt("quiet").build()) + .addOption(Option.builder("w").hasArg(false).desc("Log level 1 - errors and warnings").longOpt("warn").build()) + .addOption(Option.builder("i").hasArg(false).desc("Log level 2 - info").longOpt("info").build()) + .addOption(Option.builder("V").hasArg(false).desc("Log level 3 - verbose").longOpt("verbose").build()) + .addOption(Option.builder("d").hasArg(false).desc("Log level 4 - debug").longOpt("debug").build()) +) +options.addOption(Option.builder("v").hasArg(false).desc("display the Groovy and JVM versions").longOpt("version").build()); + +@Field CommandLine cmd + +// stopAtNonOption = true: parsing stops at the command name, leaving it and its arguments in cmd.args +cmd = new DefaultParser().parse(options, args, true); + +if (cmd.hasOption('h')) { + grapeHelp() + return +} + +if (cmd.hasOption('v')) { + String version = GroovySystem.getVersion(); + println "Groovy Version: $version JVM: ${System.getProperty('java.version')}" + return +} + +// -D name=value pairs are held by the parsed CommandLine; Options only declares the flags and has no getOptionProperties +if (cmd.hasOption('D')) { + cmd.getOptionProperties('D')?.each { k, v -> + System.setProperty(k, v) + } +} + +// dispatch: arg[0] selects the command, which receives the full positional array plus the global CommandLine +String[] arg = cmd.args +if (arg?.length == 0) { + grapeHelp() +} else if (commands.containsKey(arg[0])) { + commands[arg[0]].closure(arg, cmd) +} else { + println "grape: '${arg[0]}' is not a grape command. See 'grape --help'" +}
http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/tools/ast/TransformTestHelper.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/tools/ast/TransformTestHelper.groovy b/src/main/groovy/org/codehaus/groovy/tools/ast/TransformTestHelper.groovy new file mode 100644 index 0000000..ce98d58 --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/tools/ast/TransformTestHelper.groovy @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.tools.ast + +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.classgen.GeneratorContext +import org.codehaus.groovy.control.CompilationUnit +import org.codehaus.groovy.control.CompilationUnit.PrimaryClassNodeOperation +import org.codehaus.groovy.control.CompilePhase +import org.codehaus.groovy.control.CompilerConfiguration +import org.codehaus.groovy.control.SourceUnit +import org.codehaus.groovy.transform.ASTTransformation + +import java.security.CodeSource + +/* +* This TestHarness exists so that a global transform can be run without +* using the Jar services mechanism, which requires building a jar. 
+* +* To use this simply create an instance of TransformTestHelper with +* an ASTTransformation and CompilePhase, then invoke parse(File) or +* parse(String). +* +* This test harness is not exactly the same as executing a global transformation +* but can greatly aid in debugging and testing a transform. You should still +* test your global transformation when packaged as a jar service before +* releasing it. +* +* @author Hamlet D'Arcy +*/ +class TransformTestHelper { + + private final ASTTransformation transform + private final CompilePhase phase + + /** + * Creates the test helper. + * @param transform + * the transform to run when compiling the file later + * @param phase + * the phase to run the transform in + */ + def TransformTestHelper(ASTTransformation transform, CompilePhase phase) { + this.transform = transform + this.phase = phase + } + + /** + * Compiles the File into a Class applying the transform specified in the constructor. + * A new TestHarnessClassLoader is created for every call. + * @param input + * must be a groovy source file + * @return the class compiled from the file + */ + public Class parse(File input) { + TestHarnessClassLoader loader = new TestHarnessClassLoader(transform, phase) + return loader.parseClass(input) + } + + /** + * Compiles the String into a Class applying the transform specified in the constructor. + * A new TestHarnessClassLoader is created for every call. + * @param input + * must be a valid groovy source string + * @return the class compiled from the source text + */ + public Class parse(String input) { + TestHarnessClassLoader loader = new TestHarnessClassLoader(transform, phase) + return loader.parseClass(input) + } +} + +/** +* ClassLoader exists so that TestHarnessOperation can be wired into the compile. 
+* +* @author Hamlet D'Arcy +*/ +@groovy.transform.PackageScope class TestHarnessClassLoader extends GroovyClassLoader { + + private final ASTTransformation transform + private final CompilePhase phase + + TestHarnessClassLoader(ASTTransformation transform, CompilePhase phase) { + this.transform = transform + this.phase = phase + } + + // Every compilation unit created by this loader gets a TestHarnessOperation registered at the configured phase. + protected CompilationUnit createCompilationUnit(CompilerConfiguration config, CodeSource codeSource) { + CompilationUnit cu = super.createCompilationUnit(config, codeSource) + cu.addPhaseOperation(new TestHarnessOperation(transform), phase.getPhaseNumber()) + return cu + } +} + +/** +* Operation exists so that an ASTTransformation can be run against the SourceUnit. +* +* @author Hamlet D'Arcy +*/ +@groovy.transform.PackageScope class TestHarnessOperation extends PrimaryClassNodeOperation { + + private final ASTTransformation transform + + def TestHarnessOperation(transform) { + this.transform = transform; + } + + // The node array passed to the transform is null, so the transform can only rely on the SourceUnit. + public void call(SourceUnit source, GeneratorContext context, ClassNode classNode) { + transform.visit(null, source) + } +} http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/ASTTestTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/ASTTestTransformation.groovy b/src/main/groovy/org/codehaus/groovy/transform/ASTTestTransformation.groovy new file mode 100644 index 0000000..dcfe314 --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/ASTTestTransformation.groovy @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform + +import groovy.transform.CompilationUnitAware +import org.codehaus.groovy.ast.ASTNode +import org.codehaus.groovy.ast.AnnotationNode +import org.codehaus.groovy.ast.ClassCodeVisitorSupport +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.expr.ClosureExpression +import org.codehaus.groovy.ast.expr.PropertyExpression +import org.codehaus.groovy.ast.expr.VariableExpression +import org.codehaus.groovy.ast.stmt.Statement +import org.codehaus.groovy.control.CompilationUnit +import org.codehaus.groovy.control.CompilePhase +import org.codehaus.groovy.control.CompilerConfiguration +import org.codehaus.groovy.control.ErrorCollector +import org.codehaus.groovy.control.Janitor +import org.codehaus.groovy.control.ProcessingUnit +import org.codehaus.groovy.control.SourceUnit +import org.codehaus.groovy.control.customizers.ImportCustomizer +import org.codehaus.groovy.control.io.ReaderSource +import org.codehaus.groovy.runtime.MethodClosure +import org.codehaus.groovy.syntax.SyntaxException +import org.codehaus.groovy.tools.Utilities + +import static org.codehaus.groovy.ast.tools.GeneralUtils.classX +import static org.codehaus.groovy.ast.tools.GeneralUtils.propX + +/** + * Implementation of the ASTTest annotation: the annotation's test closure is moved into node metadata, and a + * compilation progress callback is registered that evaluates the closure's source text with a GroovyShell, + * at the requested phase or, when no phase is given, on every progress callback. 
+ * The test script's binding provides 'sourceUnit', 'node', 'lookup', 'compilationUnit' and 'compilePhase'. + */ +@GroovyASTTransformation(phase = CompilePhase.SEMANTIC_ANALYSIS) +class ASTTestTransformation extends AbstractASTTransformation implements CompilationUnitAware { + private CompilationUnit compilationUnit + + // nodes[0] is the annotation; nodes[1] is the annotated node, exposed to the test as 'node' + void visit(final ASTNode[] nodes, final 
SourceUnit source) { + AnnotationNode annotationNode = nodes[0] + def member = annotationNode.getMember('phase') + def phase = null + if (member) { + if (member instanceof VariableExpression) { + phase = CompilePhase.valueOf(member.text) + } else if (member instanceof PropertyExpression) { + phase = CompilePhase.valueOf(member.propertyAsString) + } + annotationNode.setMember('phase', propX(classX(ClassHelper.make(CompilePhase)), phase.toString())) + } + member = annotationNode.getMember('value') + if (member && !(member instanceof ClosureExpression)) { + throw new SyntaxException("ASTTest value must be a closure", member.getLineNumber(), member.getColumnNumber()) + } + if (!member && !annotationNode.getNodeMetaData(ASTTestTransformation)) { + throw new SyntaxException("Missing test expression", annotationNode.getLineNumber(), annotationNode.getColumnNumber()) + } + // convert value into node metadata so that the expression doesn't mix up with other AST xforms like type checking + annotationNode.putNodeMetaData(ASTTestTransformation, member) + annotationNode.getMembers().remove('value') + + def pcallback = compilationUnit.progressCallback + def callback = new CompilationUnit.ProgressCallback() { + Binding binding = new Binding([:].withDefault {null}) + + @Override + void call(final ProcessingUnit context, final int phaseRef) { + if (phase==null || phaseRef == phase.phaseNumber) { + ClosureExpression testClosure = nodes[0].getNodeMetaData(ASTTestTransformation) + StringBuilder sb = new StringBuilder() + // use one Janitor for all lines and clean it up afterwards, so that whatever the reader + // source caches between getLine calls is released instead of leaking + Janitor janitor = new Janitor() + try { + for (int i = testClosure.lineNumber; i <= testClosure.lastLineNumber; i++) { + sb.append(source.source.getLine(i, janitor)).append('\n') + } + } finally { + janitor.cleanup() + } + // keep only the closure body: text after the opening brace up to the last '}' (NOTE(review): confirm the +1 column offset) + def testSource = 
sb.substring(testClosure.columnNumber + 1, sb.length()) + testSource = testSource.substring(0, testSource.lastIndexOf('}')) + CompilerConfiguration config = new CompilerConfiguration() + def customizer = new ImportCustomizer() + config.addCompilationCustomizers(customizer) + binding['sourceUnit'] = source 
+ binding['node'] = nodes[1] + binding['lookup'] = new MethodClosure(LabelFinder, "lookup").curry(nodes[1]) + binding['compilationUnit'] = compilationUnit + binding['compilePhase'] = CompilePhase.fromPhaseNumber(phaseRef) + + GroovyShell shell = new GroovyShell(binding, config) + + source.AST.imports.each { + customizer.addImport(it.alias, it.type.name) + } + source.AST.starImports.each { + customizer.addStarImports(it.packageName) + } + source.AST.staticImports.each { + customizer.addStaticImport(it.value.alias, it.value.type.name, it.value.fieldName) + } + source.AST.staticStarImports.each { + customizer.addStaticStars(it.value.className) + } + shell.evaluate(testSource) + } + } + } + + // keep any callback that is already installed: chain it with ours instead of replacing it + if (pcallback!=null) { + if (pcallback instanceof ProgressCallbackChain) { + pcallback.addCallback(callback) + } else { + pcallback = new ProgressCallbackChain(pcallback, callback) + } + callback = pcallback + } + + compilationUnit.setProgressCallback(callback) + + } + + void setCompilationUnit(final CompilationUnit unit) { + this.compilationUnit = unit + } + + // NOTE(review): this private helper class is not referenced anywhere else in this file + private static class AssertionSourceDelegatingSourceUnit extends SourceUnit { + private final ReaderSource delegate + + AssertionSourceDelegatingSourceUnit(final String name, final ReaderSource source, final CompilerConfiguration flags, final GroovyClassLoader loader, final ErrorCollector er) { + super(name, '', flags, loader, er) + delegate = source + } + + @Override + String getSample(final int line, final int column, final Janitor janitor) { + String sample = null; + String text = delegate.getLine(line, janitor); + + if (text != null) { + if (column > 0) { + String marker = Utilities.repeatString(" ", column - 1) + "^"; + + if (column > 40) { + int start = column - 30 - 1; + int end = (column + 10 > text.length() ? 
text.length() : column + 10 - 1); + sample = " " + text.substring(start, end) + Utilities.eol() + " " + + marker.substring(start, marker.length()); + } else { + sample = " " + text + Utilities.eol() + " " + marker; + } + } else { + sample = text; + } + } + + return sample; + + } + + } + + // Forwards each progress notification to every registered callback, in registration order. + private static class ProgressCallbackChain extends CompilationUnit.ProgressCallback { + + private final List<CompilationUnit.ProgressCallback> chain = new LinkedList<CompilationUnit.ProgressCallback>() + + ProgressCallbackChain(CompilationUnit.ProgressCallback... callbacks) { + if (callbacks!=null) { + callbacks.each { addCallback(it) } + } + } + + public void addCallback(CompilationUnit.ProgressCallback callback) { + chain << callback + } + + @Override + void call(final ProcessingUnit context, final int phase) { + chain*.call(context, phase) + } + } + + // Collects the statements whose statement label equals the requested label; the test script's 'lookup' + // binding is lookup() curried with the annotated node. + public static class LabelFinder extends ClassCodeVisitorSupport { + + public static List<Statement> lookup(MethodNode node, String label) { + LabelFinder finder = new LabelFinder(label, null) + node.code.visit(finder) + + finder.targets + } + + public static List<Statement> lookup(ClassNode node, String label) { + LabelFinder finder = new LabelFinder(label, null) + node.methods*.code*.visit(finder) + node.declaredConstructors*.code*.visit(finder) + + finder.targets + } + + private final String label + private final SourceUnit unit + + private final List<Statement> targets = new LinkedList<Statement>(); + + LabelFinder(final String label, final SourceUnit unit) { + this.label = label + this.unit = unit; + } + + @Override + protected SourceUnit getSourceUnit() { + unit + } + + @Override + protected void visitStatement(final Statement statement) { + super.visitStatement(statement) + if (statement.statementLabel==label) targets << statement + } + + List<Statement> getTargets() { + return Collections.unmodifiableList(targets) + } + } + +}
http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/ConditionalInterruptibleASTTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/ConditionalInterruptibleASTTransformation.groovy b/src/main/groovy/org/codehaus/groovy/transform/ConditionalInterruptibleASTTransformation.groovy new file mode 100644 index 0000000..2cda121 --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/ConditionalInterruptibleASTTransformation.groovy @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.codehaus.groovy.transform + +import groovy.transform.ConditionalInterrupt +import org.codehaus.groovy.ast.AnnotatedNode +import org.codehaus.groovy.ast.AnnotationNode +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.FieldNode +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.Parameter +import org.codehaus.groovy.ast.PropertyNode +import org.codehaus.groovy.ast.expr.ArgumentListExpression +import org.codehaus.groovy.ast.expr.ClosureExpression +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.VariableExpression +import org.codehaus.groovy.ast.tools.ClosureUtils +import org.codehaus.groovy.control.CompilePhase + +/** + * Allows "interrupt-safe" executions of scripts by adding a custom conditional + * check on loops (for, while, do) and first statement of closures. By default, also adds an interrupt check + * statement on the beginning of method calls. + * + * @see groovy.transform.ConditionalInterrupt + * @author Cedric Champeau + * @author Hamlet D'Arcy + * @author Paul King + * @since 1.8.0 + */ +@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION) +public class ConditionalInterruptibleASTTransformation extends AbstractInterruptibleASTTransformation { + + private static final ClassNode MY_TYPE = ClassHelper.make(ConditionalInterrupt) + + private ClosureExpression conditionNode + private String conditionMethod + private MethodCallExpression conditionCallExpression + private ClassNode currentClass + + protected ClassNode type() { + return MY_TYPE + } + + // Reads the mandatory closure 'value' and derives the name of the generated condition method and the call to it. + protected void setupTransform(AnnotationNode node) { + super.setupTransform(node) + def member = node.getMember("value") + if (!(member instanceof ClosureExpression)) internalError("Expected closure value for annotation parameter 'value'. 
Found $member") + conditionNode = member; + conditionMethod = 'conditionalTransform' + node.hashCode() + '$condition' + conditionCallExpression = new MethodCallExpression(new VariableExpression('this'), conditionMethod, new ArgumentListExpression()) + } + + protected String getErrorMessage() { + 'Execution interrupted. The following condition failed: ' + convertClosureToSource(conditionNode) + } + + // Adds the condition closure's code as a private synthetic method; members are only visited when applyToAllMembers is set. + void visitClass(ClassNode type) { + currentClass = type + def method = type.addMethod(conditionMethod, ACC_PRIVATE | ACC_SYNTHETIC, ClassHelper.OBJECT_TYPE, Parameter.EMPTY_ARRAY, ClassNode.EMPTY_ARRAY, conditionNode.code) + method.synthetic = true + if (applyToAllMembers) { + super.visitClass(type) + } + } + + protected Expression createCondition() { + conditionCallExpression + } + + @Override + void visitAnnotations(AnnotatedNode node) { + // this transformation does not apply on annotation nodes + // visiting could lead to stack overflows + } + + @Override + void visitField(FieldNode node) { + if (!node.isStatic() && !node.isSynthetic()) { + super.visitField node + } + } + + @Override + void visitProperty(PropertyNode node) { + if (!node.isStatic() && !node.isSynthetic()) { + super.visitProperty node + } + } + + @Override + void visitClosureExpression(ClosureExpression closureExpr) { + if (closureExpr == conditionNode) return // do not visit the closure from the annotation itself + def code = closureExpr.code + closureExpr.code = wrapBlock(code) + super.visitClosureExpression closureExpr + } + + // Wraps non-synthetic, non-static, non-abstract method bodies when checkOnMethodStart is set; a script's + // no-arg run() is visited without a method-start check. 
+ @Override + void visitMethod(MethodNode node) { + // the generated condition method is added as synthetic in visitClass, so skip it when it IS synthetic + if (node.name == conditionMethod && node.isSynthetic()) return // do not visit the generated method + if (node.name == 'run' && currentClass.isScript() && node.parameters.length == 0) { + // the run() method should not have the statement added, otherwise the script binding won't be set before + // the condition is actually tested + super.visitMethod(node) + } else { + if (checkOnMethodStart && !node.isSynthetic() && 
!node.isStatic() && !node.isAbstract()) { + def code = node.code + node.code = wrapBlock(code); + } + if (!node.isSynthetic() && !node.isStatic()) super.visitMethod(node) + } + } + + /** + * Converts a ClosureExpression into the String source. + * @param expression a closure + * @return the source the closure was created from + */ + private String convertClosureToSource(ClosureExpression expression) { + try { + return ClosureUtils.convertClosureToSource(this.source.source, expression); + } catch(Exception e) { + return e.message + } + } +} http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/ThreadInterruptibleASTTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/ThreadInterruptibleASTTransformation.groovy b/src/main/groovy/org/codehaus/groovy/transform/ThreadInterruptibleASTTransformation.groovy new file mode 100644 index 0000000..a4fb4c3 --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/ThreadInterruptibleASTTransformation.groovy @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.codehaus.groovy.transform + +import groovy.transform.CompileStatic +import groovy.transform.ThreadInterrupt +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.Parameter +import org.codehaus.groovy.ast.expr.ArgumentListExpression +import org.codehaus.groovy.ast.expr.ClassExpression +import org.codehaus.groovy.ast.expr.ClosureExpression +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.control.CompilePhase + +/** + * Allows "interrupt-safe" executions of scripts by adding Thread.currentThread().isInterrupted() + * checks on loops (for, while, do) and first statement of closures. By default, also adds an interrupt check + * statement on the beginning of method calls. + * + * @see groovy.transform.ThreadInterrupt + * + * @author Cedric Champeau + * @author Hamlet D'Arcy + * + * @since 1.8.0 + */ +@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION) +@CompileStatic +public class ThreadInterruptibleASTTransformation extends AbstractInterruptibleASTTransformation { + + private static final ClassNode MY_TYPE = ClassHelper.make(ThreadInterrupt) + private static final ClassNode THREAD_TYPE = ClassHelper.make(Thread) + private static final MethodNode CURRENTTHREAD_METHOD + private static final MethodNode ISINTERRUPTED_METHOD + // Thread.currentThread() / isInterrupted() MethodNodes, looked up once and attached as methodTarget in createCondition() + static { + CURRENTTHREAD_METHOD = THREAD_TYPE.getMethod('currentThread', Parameter.EMPTY_ARRAY) + ISINTERRUPTED_METHOD = THREAD_TYPE.getMethod('isInterrupted', Parameter.EMPTY_ARRAY) + } + /** Returns MY_TYPE, the ClassNode of the ThreadInterrupt annotation. */ + protected ClassNode type() { + return MY_TYPE; + } + /** Message for a failed check (presumably used as the thrown exception's message; see the base class). */ + protected String getErrorMessage() { + 'Execution interrupted. The current thread has been interrupted.' 
+ } + /** Builds the AST for Thread.currentThread().isInterrupted(), with explicit method targets and implicitThis disabled on both calls. */ + protected Expression createCondition() { + def currentThread = new MethodCallExpression(new ClassExpression(THREAD_TYPE), + 'currentThread', + ArgumentListExpression.EMPTY_ARGUMENTS) + currentThread.methodTarget = CURRENTTHREAD_METHOD + def isInterrupted = new MethodCallExpression( + currentThread, + 'isInterrupted', ArgumentListExpression.EMPTY_ARGUMENTS) + isInterrupted.methodTarget = ISINTERRUPTED_METHOD + [currentThread, isInterrupted]*.implicitThis = false + + isInterrupted + } + + /** Wraps every closure body with wrapBlock (the class doc: a check on the first statement of closures), then keeps visiting. */ + @Override + public void visitClosureExpression(ClosureExpression closureExpr) { + def code = closureExpr.code + closureExpr.code = wrapBlock(code) + super.visitClosureExpression closureExpr + } + /** When checkOnMethodStart is set, wraps non-synthetic, non-abstract method bodies (static ones included), then keeps visiting. */ + @Override + public void visitMethod(MethodNode node) { + if (checkOnMethodStart && !node.isSynthetic() && !node.isAbstract()) { + def code = node.code + node.code = wrapBlock(code); + } + super.visitMethod(node) + } +} http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy b/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy new file mode 100644 index 0000000..fbc923b --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.transform
+
+import groovy.transform.TimedInterrupt
+import org.codehaus.groovy.ast.ASTNode
+import org.codehaus.groovy.ast.AnnotatedNode
+import org.codehaus.groovy.ast.AnnotationNode
+import org.codehaus.groovy.ast.ClassCodeVisitorSupport
+import org.codehaus.groovy.ast.ClassHelper
+import org.codehaus.groovy.ast.ClassNode
+import org.codehaus.groovy.ast.FieldNode
+import org.codehaus.groovy.ast.MethodNode
+import org.codehaus.groovy.ast.PropertyNode
+import org.codehaus.groovy.ast.expr.ClosureExpression
+import org.codehaus.groovy.ast.expr.ConstantExpression
+import org.codehaus.groovy.ast.expr.DeclarationExpression
+import org.codehaus.groovy.ast.expr.Expression
+import org.codehaus.groovy.ast.stmt.BlockStatement
+import org.codehaus.groovy.ast.stmt.DoWhileStatement
+import org.codehaus.groovy.ast.stmt.ForStatement
+import org.codehaus.groovy.ast.stmt.WhileStatement
+import org.codehaus.groovy.control.CompilePhase
+import org.codehaus.groovy.control.SourceUnit
+
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.TimeoutException
+
+import static org.codehaus.groovy.ast.ClassHelper.make
+import static org.codehaus.groovy.ast.tools.GeneralUtils.args
+import static org.codehaus.groovy.ast.tools.GeneralUtils.callX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.classX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.constX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.ctorX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.ifS
+import static org.codehaus.groovy.ast.tools.GeneralUtils.ltX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.plusX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.propX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.throwS
+import static org.codehaus.groovy.ast.tools.GeneralUtils.varX
+
+/**
+ * Allows "interrupt-safe" executions of scripts by adding timer expiration
+ * checks on loops (for, while, do) and first statement of closures. By default,
+ * also adds an interrupt check statement at the beginning of method bodies.
+ *
+ * @author Cedric Champeau
+ * @author Hamlet D'Arcy
+ * @author Paul King
+ * @see groovy.transform.TimedInterrupt
+ * @since 1.8.0
+ */
+@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION)
+public class TimedInterruptibleASTTransformation extends AbstractASTTransformation {
+
+    private static final ClassNode MY_TYPE = make(TimedInterrupt)
+    private static final String CHECK_METHOD_START_MEMBER = 'checkOnMethodStart'
+    private static final String APPLY_TO_ALL_CLASSES = 'applyToAllClasses'
+    private static final String APPLY_TO_ALL_MEMBERS = 'applyToAllMembers'
+    private static final String THROWN_EXCEPTION_TYPE = "thrown"
+
+    /**
+     * Reads the {@code @TimedInterrupt} members and applies a freshly created
+     * {@link TimedInterruptionVisitor} to the scope selected by the annotated node and the
+     * {@code applyToAllClasses}/{@code applyToAllMembers} flags.
+     *
+     * @param nodes  {@code [AnnotationNode, AnnotatedNode]} as passed by the compiler
+     * @param source the source unit being compiled
+     */
+    public void visit(ASTNode[] nodes, SourceUnit source) {
+        init(nodes, source);
+        AnnotationNode node = nodes[0]
+        AnnotatedNode annotatedNode = nodes[1]
+        if (!MY_TYPE.equals(node.getClassNode())) {
+            internalError("Transformation called from wrong annotation: $node.classNode.name")
+        }
+
+        def checkOnMethodStart = getConstantAnnotationParameter(node, CHECK_METHOD_START_MEMBER, Boolean.TYPE, true)
+        def applyToAllMembers = getConstantAnnotationParameter(node, APPLY_TO_ALL_MEMBERS, Boolean.TYPE, true)
+        // applyToAllClasses is only read when applyToAllMembers is set; otherwise it is forced to false
+        def applyToAllClasses = applyToAllMembers ? getConstantAnnotationParameter(node, APPLY_TO_ALL_CLASSES, Boolean.TYPE, true) : false
+        def maximum = getConstantAnnotationParameter(node, 'value', Long.TYPE, Long.MAX_VALUE)
+        def thrown = AbstractInterruptibleASTTransformation.getClassAnnotationParameter(node, THROWN_EXCEPTION_TYPE, make(TimeoutException))
+
+        Expression unit = node.getMember('unit') ?: propX(classX(TimeUnit), "SECONDS")
+
+        // should be limited to the current SourceUnit or propagated to the whole CompilationUnit
+        // DO NOT inline visitor creation in code below. It has state that must not persist between calls
+        if (applyToAllClasses) {
+            // guard every class and method defined in this script
+            source.getAST()?.classes?.each { ClassNode it ->
+                def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode())
+                visitor.visitClass(it)
+            }
+        } else if (annotatedNode instanceof ClassNode) {
+            // only guard this particular class
+            def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode())
+            visitor.visitClass annotatedNode
+        } else if (!applyToAllMembers && annotatedNode instanceof MethodNode) {
+            // only guard this particular method (plus initCode for class)
+            def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode())
+            visitor.visitMethod annotatedNode
+            visitor.visitClass annotatedNode.declaringClass
+        } else if (!applyToAllMembers && annotatedNode instanceof FieldNode) {
+            // only guard this particular field (plus initCode for class)
+            def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode())
+            visitor.visitField annotatedNode
+            visitor.visitClass annotatedNode.declaringClass
+        } else if (!applyToAllMembers && annotatedNode instanceof DeclarationExpression) {
+            // only guard this particular declaration (plus initCode for class)
+            def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode())
+            visitor.visitDeclarationExpression annotatedNode
+            visitor.visitClass annotatedNode.declaringClass
+        } else {
+            // only guard the script class
+            source.getAST()?.classes?.each { ClassNode it ->
+                if (it.isScript()) {
+                    def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode())
+                    visitor.visitClass(it)
+                }
+            }
+        }
+    }
+
+    /**
+     * Reads a constant annotation member and coerces it to {@code type}.
+     *
+     * @param node          the annotation to read from
+     * @param parameterName the annotation member name
+     * @param type          the type the constant value is coerced to (e.g. {@code Boolean.TYPE}, {@code Long.TYPE})
+     * @param defaultValue  returned when the member is not set on the annotation
+     * @return the coerced member value, or {@code defaultValue} if the member is absent
+     * @throws RuntimeException if the member is not a constant or cannot be coerced to {@code type};
+     *         a coercion failure is attached as the cause
+     */
+    static def getConstantAnnotationParameter(AnnotationNode node, String parameterName, Class type, defaultValue) {
+        def member = node.getMember(parameterName)
+        if (member) {
+            if (member instanceof ConstantExpression) {
+                // TODO not sure this try offers value - testing Groovy annotation type handing - throw GroovyBugError or remove?
+                try {
+                    return member.value.asType(type)
+                } catch (e) {
+                    // report the type actually expected (not always boolean) and keep the coercion failure as cause
+                    internalError("Expecting ${type.name} value for ${parameterName} annotation parameter. Found $member", e)
+                }
+            } else {
+                internalError("Expecting ${type.name} value for ${parameterName} annotation parameter. Found $member")
+            }
+        }
+        return defaultValue
+    }
+
+    /**
+     * Aborts the transformation with a {@code RuntimeException} prefixed by "Internal error: ".
+     *
+     * @param message description of the problem
+     * @param cause   optional underlying exception, preserved as the exception cause
+     */
+    private static void internalError(String message, Throwable cause = null) {
+        throw new RuntimeException("Internal error: $message", cause)
+    }
+
+    /**
+     * Visitor that inserts the timeout checks. It keeps per-class state (the fields it adds),
+     * and the names of those fields are derived from the annotation's hash code, so a new
+     * instance is created for each class it is applied to.
+     */
+    private static class TimedInterruptionVisitor extends ClassCodeVisitorSupport {
+        final private SourceUnit source
+        final private boolean checkOnMethodStart
+        final private boolean applyToAllClasses
+        final private boolean applyToAllMembers
+        private FieldNode expireTimeField = null
+        private FieldNode startTimeField = null
+        private final Expression unit
+        private final maximum
+        private final ClassNode thrown
+        private final String basename
+
+        TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, hash) {
+            this.source = source
+            this.checkOnMethodStart = checkOnMethodStart
+            this.applyToAllClasses = applyToAllClasses
+            this.applyToAllMembers = applyToAllMembers
+            this.unit = unit
+            this.maximum = maximum
+            this.thrown = thrown
+            this.basename = 'timedInterrupt' + hash
+        }
+
+        /**
+         * Builds the check
+         * {@code if (this.<basename>$expireTime < System.nanoTime()) throw new <thrown>(message)},
+         * where the message names the configured maximum, the unit and the start time.
+         *
+         * @return Returns the interruption check statement.
+         */
+        final createInterruptStatement() {
+            ifS(
+                    ltX(
+                            propX(varX("this"), basename + '$expireTime'),
+                            callX(make(System), 'nanoTime')
+                    ),
+                    throwS(
+                            ctorX(thrown,
+                                    args(
+                                            plusX(
+                                                    plusX(
+                                                            constX('Execution timed out after ' + maximum + ' '),
+                                                            callX(callX(unit, 'name'), 'toLowerCase', propX(classX(Locale), 'US'))
+                                                    ),
+                                                    plusX(
+                                                            constX('. Start time: '),
+                                                            propX(varX("this"), basename + '$startTime')
+                                                    )
+                                            )
+                                    )
+                            )
+                    )
+            )
+        }
+
+        /**
+         * Takes a statement and wraps it into a block statement which first element is the interruption check statement.
+         * @param statement the statement to be wrapped
+         * @return a {@link BlockStatement block statement} which first element is for checking interruption, and the
+         * second one the statement to be wrapped.
+         */
+        private wrapBlock(statement) {
+            def stmt = new BlockStatement();
+            stmt.addStatement(createInterruptStatement());
+            stmt.addStatement(statement);
+            stmt
+        }
+
+        /**
+         * Adds the synthetic {@code <basename>$expireTime} (nanoTime + maximum in nanoseconds) and
+         * {@code <basename>$startTime} (a new Date) fields at the front of the class's field list, so they
+         * are initialized first. Does nothing if the class already has them. Members are only
+         * visited when {@code applyToAllMembers} is set.
+         */
+        @Override
+        void visitClass(ClassNode node) {
+            if (node.getDeclaredField(basename + '$expireTime')) {
+                return
+            }
+            expireTimeField = node.addField(basename + '$expireTime',
+                    ACC_FINAL | ACC_PRIVATE,
+                    ClassHelper.long_TYPE,
+                    plusX(
+                            callX(make(System), 'nanoTime'),
+                            callX(
+                                    propX(classX(TimeUnit), 'NANOSECONDS'),
+                                    'convert',
+                                    args(constX(maximum, true), unit)
+                            )
+                    )
+            );
+            expireTimeField.synthetic = true
+            startTimeField = node.addField(basename + '$startTime',
+                    ACC_FINAL | ACC_PRIVATE,
+                    make(Date),
+                    ctorX(make(Date))
+            )
+            startTimeField.synthetic = true
+
+            // force these fields to be initialized first
+            node.fields.remove(expireTimeField)
+            node.fields.remove(startTimeField)
+            node.fields.add(0, startTimeField)
+            node.fields.add(0, expireTimeField)
+            if (applyToAllMembers) {
+                super.visitClass node
+            }
+        }
+
+        /** Guards a closure: the check is inserted as its first statement (wrapping non-block bodies). */
+        @Override
+        void visitClosureExpression(ClosureExpression closureExpr) {
+            def code = closureExpr.code
+            if (code instanceof BlockStatement) {
+                code.statements.add(0, createInterruptStatement())
+            } else {
+                closureExpr.code = wrapBlock(code)
+            }
+            super.visitClosureExpression closureExpr
+        }
+
+        /** Only non-static, non-synthetic fields are visited. */
+        @Override
+        void visitField(FieldNode node) {
+            if (!node.isStatic() && !node.isSynthetic()) {
+                super.visitField node
+            }
+        }
+
+        /** Only non-static, non-synthetic properties are visited. */
+        @Override
+        void visitProperty(PropertyNode node) {
+            if (!node.isStatic() && !node.isSynthetic()) {
+                super.visitProperty node
+            }
+        }
+
+        /**
+         * Shortcut method which avoids duplicating code for every type of loop.
+         * Actually wraps the loopBlock of different types of loop statements.
+         */
+        private visitLoop(loopStatement) {
+            def statement = loopStatement.loopBlock
+            loopStatement.loopBlock = wrapBlock(statement)
+        }
+
+        @Override
+        void visitForLoop(ForStatement forStatement) {
+            visitLoop(forStatement)
+            super.visitForLoop(forStatement)
+        }
+
+        @Override
+        void visitDoWhileLoop(final DoWhileStatement doWhileStatement) {
+            visitLoop(doWhileStatement)
+            super.visitDoWhileLoop(doWhileStatement)
+        }
+
+        @Override
+        void visitWhileLoop(final WhileStatement whileStatement) {
+            visitLoop(whileStatement)
+            super.visitWhileLoop(whileStatement)
+        }
+
+        /**
+         * Guards the start of non-synthetic, non-static, non-abstract methods when
+         * {@code checkOnMethodStart} is set. Static methods are skipped because the check reads
+         * instance fields; only non-synthetic, non-static methods are traversed further.
+         */
+        @Override
+        void visitMethod(MethodNode node) {
+            if (checkOnMethodStart && !node.isSynthetic() && !node.isStatic() && !node.isAbstract()) {
+                def code = node.code
+                node.code = wrapBlock(code);
+            }
+            if (!node.isSynthetic() && !node.isStatic()) {
+                super.visitMethod(node)
+            }
+        }
+
+        protected SourceUnit getSourceUnit() {
+            return source;
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/tailrec/AstHelper.groovy
----------------------------------------------------------------------
diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/AstHelper.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/AstHelper.groovy
new file mode 100644
index 0000000..206b0df
--- /dev/null
+++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/AstHelper.groovy
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.transform.tailrec
+
+import groovy.transform.CompileStatic
+import org.codehaus.groovy.ast.ClassNode
+import org.codehaus.groovy.ast.expr.Expression
+import org.codehaus.groovy.ast.expr.VariableExpression
+import org.codehaus.groovy.ast.stmt.ContinueStatement
+import org.codehaus.groovy.ast.stmt.ExpressionStatement
+import org.codehaus.groovy.ast.stmt.Statement
+import org.codehaus.groovy.ast.stmt.ThrowStatement
+
+import java.lang.reflect.Modifier
+
+import static org.codehaus.groovy.ast.tools.GeneralUtils.classX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.declS
+import static org.codehaus.groovy.ast.tools.GeneralUtils.propX
+import static org.codehaus.groovy.ast.tools.GeneralUtils.varX
+
+/**
+ * Factory methods for the few standard AST constructs needed by the tail-recursion transformation.
+ *
+ * @author Johannes Link
+ */
+@CompileStatic
+class AstHelper {
+    /**
+     * Creates the declaration statement {@code variableType variableName = value}.
+     *
+     * @param variableName          name of the declared variable
+     * @param variableType          declared type of the variable
+     * @param value                 initial value expression
+     * @param variableShouldBeFinal when true the variable gets the {@code final} modifier
+     * @return the declaration as an expression statement
+     */
+    static ExpressionStatement createVariableDefinition(String variableName, ClassNode variableType, Expression value, boolean variableShouldBeFinal = false ) {
+        VariableExpression declared = varX(variableName, variableType)
+        if (variableShouldBeFinal) {
+            declared.setModifiers(Modifier.FINAL)
+        }
+        return (ExpressionStatement) declS(declared, value)
+    }
+
+    /**
+     * Creates {@code variableType aliasName = variableName}, i.e. a new non-final variable
+     * initialized from an existing one of the same type.
+     */
+    static ExpressionStatement createVariableAlias(String aliasName, ClassNode variableType, String variableName ) {
+        VariableExpression aliased = varX(variableName, variableType)
+        return createVariableDefinition(aliasName, variableType, aliased)
+    }
+
+    /**
+     * Creates a variable reference from a spec map.
+     *
+     * @param variableSpec map with a {@code name} (String) and a {@code type} (ClassNode) entry
+     */
+    static VariableExpression createVariableReference(Map variableSpec) {
+        String name = (String) variableSpec.name
+        ClassNode type = (ClassNode) variableSpec.type
+        return varX(name, type)
+    }
+
+    /**
+     * Builds {@code continue _RECUR_HERE_}, which jumps to the start label of the surrounding while loop.
+     * Does not work from within Closures.
+     */
+    static Statement recurStatement() {
+        return new ContinueStatement(InWhileLoopWrapper.LOOP_LABEL)
+    }
+
+    /**
+     * Builds {@code throw InWhileLoopWrapper.LOOP_EXCEPTION}. The exception is caught by the
+     * surrounding loop, which then jumps to its start label. Also works from within Closures,
+     * but is a tiny bit slower.
+     */
+    static Statement recurByThrowStatement() {
+        return new ThrowStatement(propX(classX(InWhileLoopWrapper), 'LOOP_EXCEPTION'))
+    }
+}
http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/tailrec/CollectRecursiveCalls.groovy
----------------------------------------------------------------------
diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/CollectRecursiveCalls.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/CollectRecursiveCalls.groovy
new file mode 100644
index 0000000..2c7e6de
--- /dev/null
+++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/CollectRecursiveCalls.groovy
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.transform.tailrec
+
+import groovy.transform.CompileStatic
+import org.codehaus.groovy.ast.CodeVisitorSupport
+import org.codehaus.groovy.ast.MethodNode
+import org.codehaus.groovy.ast.expr.Expression
+import org.codehaus.groovy.ast.expr.MethodCallExpression
+import org.codehaus.groovy.ast.expr.StaticMethodCallExpression;
+
+/**
+ * Collects every recursive call (instance or static) found in a method's code.
+ *
+ * @author Johannes Link
+ */
+@CompileStatic
+class CollectRecursiveCalls extends CodeVisitorSupport {
+    MethodNode method
+    List<Expression> recursiveCalls = []
+
+    public void visitMethodCallExpression(MethodCallExpression call) {
+        collectIfRecursive(call)
+        super.visitMethodCallExpression(call)
+    }
+
+    public void visitStaticMethodCallExpression(StaticMethodCallExpression call) {
+        collectIfRecursive(call)
+        super.visitStaticMethodCallExpression(call)
+    }
+
+    /** Records {@code call} when it is a recursive call of {@link #method}. */
+    private void collectIfRecursive(Expression call) {
+        if (isRecursive(call)) {
+            recursiveCalls.add(call)
+        }
+    }
+
+    private boolean isRecursive(call) {
+        return new RecursivenessTester().isRecursive(method: method, call: call)
+    }
+
+    /**
+     * Walks the code of {@code method} and returns the recursive calls found, in visiting order.
+     * The returned list is this collector's own list and is cleared on the next invocation.
+     */
+    synchronized List<Expression> collect(MethodNode method) {
+        recursiveCalls.clear()
+        this.method = method
+        method.code.visit(this)
+        return recursiveCalls
+    }
+}
http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/tailrec/HasRecursiveCalls.groovy
----------------------------------------------------------------------
diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/HasRecursiveCalls.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/HasRecursiveCalls.groovy
new file mode 100644
index 0000000..79f8e6d
--- /dev/null
+++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/HasRecursiveCalls.groovy
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.transform.tailrec
+
+import groovy.transform.CompileStatic
+import org.codehaus.groovy.ast.CodeVisitorSupport
+import org.codehaus.groovy.ast.MethodNode
+import org.codehaus.groovy.ast.expr.MethodCallExpression
+import org.codehaus.groovy.ast.expr.StaticMethodCallExpression
+
+/**
+ * Checks whether a method's code contains at least one recursive call (instance or static).
+ * The arguments of a call found to be recursive are not visited further.
+ *
+ * @author Johannes Link
+ */
+@CompileStatic
+class HasRecursiveCalls extends CodeVisitorSupport {
+    MethodNode method
+    boolean hasRecursiveCalls = false
+
+    public void visitMethodCallExpression(MethodCallExpression call) {
+        if (!isRecursive(call)) {
+            super.visitMethodCallExpression(call)
+            return
+        }
+        hasRecursiveCalls = true
+    }
+
+    public void visitStaticMethodCallExpression(StaticMethodCallExpression call) {
+        if (!isRecursive(call)) {
+            super.visitStaticMethodCallExpression(call)
+            return
+        }
+        hasRecursiveCalls = true
+    }
+
+    private boolean isRecursive(call) {
+        return new RecursivenessTester().isRecursive(method: method, call: call)
+    }
+
+    /** @return true if the code of {@code method} contains a recursive call of that method */
+    synchronized boolean test(MethodNode method) {
+        this.method = method
+        hasRecursiveCalls = false
+        method.code.visit(this)
+        return hasRecursiveCalls
+    }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/tailrec/InWhileLoopWrapper.groovy
---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/InWhileLoopWrapper.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/InWhileLoopWrapper.groovy new file mode 100644 index 0000000..981f146 --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/InWhileLoopWrapper.groovy @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.codehaus.groovy.transform.tailrec
+
+import groovy.transform.CompileStatic
+import org.codehaus.groovy.ast.ClassHelper
+import org.codehaus.groovy.ast.MethodNode
+import org.codehaus.groovy.ast.Parameter
+import org.codehaus.groovy.ast.VariableScope
+import org.codehaus.groovy.ast.expr.BooleanExpression
+import org.codehaus.groovy.ast.expr.ConstantExpression
+import org.codehaus.groovy.ast.stmt.BlockStatement
+import org.codehaus.groovy.ast.stmt.CatchStatement
+import org.codehaus.groovy.ast.stmt.ContinueStatement
+import org.codehaus.groovy.ast.stmt.EmptyStatement
+import org.codehaus.groovy.ast.stmt.Statement
+import org.codehaus.groovy.ast.stmt.TryCatchStatement
+import org.codehaus.groovy.ast.stmt.WhileStatement
+
+/**
+ * Wraps the body of a method in a {@code while (true)} loop, with the original body nested in a try-catch.
+ * This is the first step in making a tail recursive method iterative.
+ *
+ * There are two ways to invoke the next iteration step:
+ * 1. "continue _RECUR_HERE_" is used by recursive calls outside of closures
+ * 2. "throw LOOP_EXCEPTION" is used by recursive calls within closures, because "continue" cannot be invoked from there
+ *
+ * @author Johannes Link
+ */
+@CompileStatic
+class InWhileLoopWrapper {
+
+    static final String LOOP_LABEL = '_RECUR_HERE_'
+    static final GotoRecurHereException LOOP_EXCEPTION = new GotoRecurHereException()
+
+    /**
+     * Replaces the method's code with
+     * {@code while (true) { _RECUR_HERE_: try { <old body> } catch (GotoRecurHereException ignore) { continue _RECUR_HERE_ } }}.
+     * The method's code is cast to {@link BlockStatement}.
+     */
+    void wrap(MethodNode method) {
+        BlockStatement originalBody = method.code as BlockStatement
+
+        CatchStatement jumpToLoopStart = new CatchStatement(
+                new Parameter(ClassHelper.make(GotoRecurHereException), 'ignore'),
+                new ContinueStatement(LOOP_LABEL)
+        )
+        TryCatchStatement guardedBody = new TryCatchStatement(originalBody, EmptyStatement.INSTANCE)
+        guardedBody.addCatch(jumpToLoopStart)
+
+        BlockStatement loopBody = new BlockStatement([guardedBody] as List<Statement>, new VariableScope(method.variableScope))
+        WhileStatement whileLoop = new WhileStatement(new BooleanExpression(new ConstantExpression(true)), loopBody)
+
+        // the label sits on the first statement of the loop body, which is the target of "continue"
+        List<Statement> loopStatements = loopBody.statements
+        if (!loopStatements.isEmpty()) {
+            loopStatements[0].statementLabel = LOOP_LABEL
+        }
+
+        BlockStatement newBody = new BlockStatement([] as List<Statement>, new VariableScope(method.variableScope))
+        newBody.addStatement(whileLoop)
+        method.code = newBody
+    }
+}
+
+/**
+ * Thrown by recursive calls inside closures (where "continue" cannot be used) and caught inside the
+ * generated while loop, which then continues at {@link InWhileLoopWrapper#LOOP_LABEL}.
+ */
+class GotoRecurHereException extends Throwable {
+
+}
http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/tailrec/RecursivenessTester.groovy
----------------------------------------------------------------------
diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/RecursivenessTester.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/RecursivenessTester.groovy
new file mode 100644
index 0000000..7c9545a
--- /dev/null
+++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/RecursivenessTester.groovy
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.transform.tailrec
+
+import org.codehaus.groovy.ast.ClassHelper
+import org.codehaus.groovy.ast.ClassNode
+import org.codehaus.groovy.ast.MethodNode
+import org.codehaus.groovy.ast.expr.ConstantExpression
+import org.codehaus.groovy.ast.expr.MethodCallExpression
+import org.codehaus.groovy.ast.expr.StaticMethodCallExpression
+import org.codehaus.groovy.ast.expr.VariableExpression
+
+/**
+ *
+ * Test if a method call is recursive if called within a given method node.
+ * Handles static calls as well.
+ *
+ * Currently known simplifications:
+ * - Does not check for method overloading or overridden methods.
+ * - Does not check for matching return types; even void and any object type are considered to be compatible.
+ * - Argument type matching could be more specific in case of static compilation.
+ * - Method names via a GString are never considered to be recursive
+ *
+ * @author Johannes Link
+ */
+class RecursivenessTester {
+    /**
+     * Map-based entry point, used as {@code isRecursive(method: m, call: c)}.
+     * The class is dynamically compiled, so the call below is dispatched at runtime to the
+     * overload matching the actual type of {@code params.call}.
+     *
+     * @param params map with a {@code method} ({@link MethodNode}) and a {@code call}
+     *               ({@link MethodCallExpression} or {@link StaticMethodCallExpression})
+     * @return true if {@code call} is considered a recursive call of {@code method}
+     */
+    public boolean isRecursive(params) {
+        assert params.method.class == MethodNode
+        // Each alternative must be tested explicitly: `x == A || B` always passes, since a Class object is truthy.
+        assert params.call instanceof MethodCallExpression || params.call instanceof StaticMethodCallExpression
+
+        isRecursive(params.method, params.call)
+    }
+
+    /**
+     * An instance call is recursive when it targets {@code this} (explicitly or implicitly),
+     * its name is a constant equal to the method name, and its arguments match the parameters.
+     */
+    public boolean isRecursive(MethodNode method, MethodCallExpression call) {
+        if (!isCallToThis(call))
+            return false
+        // Could be a GStringExpression
+        if (! (call.method instanceof ConstantExpression))
+            return false
+        if (call.method.value != method.name)
+            return false
+        methodParamsMatchCallArgs(method, call)
+    }
+
+    /**
+     * A static call is recursive when the method is static, declared in the call's owner type,
+     * has the same name, and its parameters match the call arguments.
+     */
+    public boolean isRecursive(MethodNode method, StaticMethodCallExpression call) {
+        if (!method.isStatic())
+            return false
+        if (method.declaringClass != call.ownerType)
+            return false
+        if (call.method != method.name)
+            return false
+        methodParamsMatchCallArgs(method, call)
+    }
+
+    /** True for implicit-this calls and for calls whose receiver is the {@code this} variable. */
+    private boolean isCallToThis(MethodCallExpression call) {
+        if (call.objectExpression == null)
+            return call.isImplicitThis()
+        if (! (call.objectExpression instanceof VariableExpression)) {
+            return false
+        }
+        return call.objectExpression.isThisExpression()
+    }
+
+    /** Parameter and argument counts must be equal and each pair must be call-compatible. */
+    private boolean methodParamsMatchCallArgs(method, call) {
+        if (method.parameters.size() != call.arguments.expressions.size())
+            return false
+        def classNodePairs = [method.parameters*.type, call.arguments*.type].transpose()
+        return classNodePairs.every { ClassNode paramType, ClassNode argType ->
+            return areTypesCallCompatible(argType, paramType)
+        }
+    }
+
+    /**
+     * Parameter type and calling argument type can both be derived from the other since typing information is
+     * optional in Groovy.
+     * Since int is not derived from Integer (nor the other way around) we compare the boxed types
+     */
+    private areTypesCallCompatible(ClassNode argType, ClassNode paramType) {
+        ClassNode boxedArg = ClassHelper.getWrapper(argType)
+        ClassNode boxedParam = ClassHelper.getWrapper(paramType)
+        return boxedArg.isDerivedFrom(boxedParam) || boxedParam.isDerivedFrom(boxedArg)
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnAdderForClosures.groovy
----------------------------------------------------------------------
diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnAdderForClosures.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnAdderForClosures.groovy
new file mode 100644
index 0000000..64ebce7
--- /dev/null
+++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnAdderForClosures.groovy
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.transform.tailrec
+
+import org.codehaus.groovy.ast.ClassHelper
+import org.codehaus.groovy.ast.ClassNode
+import org.codehaus.groovy.ast.CodeVisitorSupport
+import org.codehaus.groovy.ast.MethodNode
+import org.codehaus.groovy.ast.Parameter
+import org.codehaus.groovy.ast.expr.ClosureExpression
+import org.codehaus.groovy.classgen.ReturnAdder
+
+/**
+ * Adds explicit return statements to implicit return points in a closure. This is necessary since
+ * tail-recursion is detected by having the recursive call within the return statement.
+ *
+ * @author Johannes Link
+ */
+class ReturnAdderForClosures extends CodeVisitorSupport {
+
+    /** Visits every closure found in the code of {@code method}. */
+    synchronized void visitMethod(MethodNode method) {
+        method.code.visit(this)
+    }
+
+    public void visitClosureExpression(ClosureExpression expression) {
+        // ReturnAdder only works for methods, so the closure's code is handed to it through a
+        // throw-away "dummy" method. The dummy itself is discarded, so this relies on ReturnAdder
+        // modifying the closure's statements in place. NOTE(review): confirm for non-block closure code.
+        MethodNode dummyMethod = new MethodNode("dummy", 0, ClassHelper.OBJECT_TYPE, Parameter.EMPTY_ARRAY, ClassNode.EMPTY_ARRAY, expression.code);
+        ReturnAdder returnAdder = new ReturnAdder()
+        returnAdder.visitMethod(dummyMethod);
+        super.visitClosureExpression(expression)
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnStatementToIterationConverter.groovy
----------------------------------------------------------------------
diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnStatementToIterationConverter.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnStatementToIterationConverter.groovy
new file mode 100644
index 0000000..d489ff0
--- /dev/null
+++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnStatementToIterationConverter.groovy
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.transform.tailrec
+
+import groovy.transform.CompileStatic
+import org.codehaus.groovy.ast.ClassNode
+import org.codehaus.groovy.ast.expr.BinaryExpression
+import org.codehaus.groovy.ast.expr.Expression
+import org.codehaus.groovy.ast.expr.MethodCallExpression
+import org.codehaus.groovy.ast.expr.StaticMethodCallExpression
+import org.codehaus.groovy.ast.expr.TupleExpression
+import org.codehaus.groovy.ast.expr.VariableExpression
+import org.codehaus.groovy.ast.stmt.BlockStatement
+import org.codehaus.groovy.ast.stmt.ExpressionStatement
+import org.codehaus.groovy.ast.stmt.ReturnStatement
+import org.codehaus.groovy.ast.stmt.Statement
+
+import static org.codehaus.groovy.ast.tools.GeneralUtils.assignS
+import static org.codehaus.groovy.ast.tools.GeneralUtils.varX
+
+/**
+ * Translates all return statements into an invocation of the next iteration. This can be either
+ * - "continue LOOP_LABEL": Outside closures
+ * - "throw LOOP_EXCEPTION": Inside closures
+ *
+ * Moreover, before adding the recur statement the iteration parameters (originally the method args)
+ * are set to their new value. To prevent variable aliasing, parameters are copied into temp vars
+ * before they are changed, so that their current iteration value can be used when setting other params.
+ *
+ * There is probably room for reducing the amount of variable copying being done, e.g.
+ * parameters that are only handed through need not be copied at all.
+ *
+ * @author Johannes Link
+ */
+@CompileStatic
+class ReturnStatementToIterationConverter {
+
+    /** Statement appended after the parameter reassignments to start the next iteration. */
+    Statement recurStatement = AstHelper.recurStatement()
+
+    /**
+     * Converts a return statement whose expression is a method call into a block that reassigns
+     * the iteration variables from the call's arguments and then appends {@link #recurStatement}.
+     *
+     * @param statement the return statement to convert
+     * @param positionMapping maps each argument position to a map holding the iteration variable's
+     *        {@code name} and {@code type} (a ClassNode)
+     * @return the replacement block, or {@code statement} itself if its expression is not a method call
+     */
+    Statement convert(ReturnStatement statement, Map<Integer, Map> positionMapping) {
+        Expression recursiveCall = statement.expression
+        if (!isAMethodCalls(recursiveCall)) {
+            return statement
+        }
+
+        Map<String, Map> tempMapping = [:]
+        Map<String, ExpressionStatement> tempDeclarations = [:]
+        List<ExpressionStatement> argAssignments = []
+
+        BlockStatement result = new BlockStatement()
+        result.statementLabel = statement.statementLabel
+
+        List<Expression> arguments = getArguments(recursiveCall)
+
+        /* Create temp declarations for all method arguments, before any iteration variable is reassigned.
+         * Add the declarations and var mapping to tempMapping and tempDeclarations for further reference.
+         */
+        arguments.eachWithIndex { Expression argument, int index ->
+            ExpressionStatement tempDeclaration = createTempDeclaration(index, positionMapping, tempMapping, tempDeclarations)
+            result.addStatement(tempDeclaration)
+        }
+
+        // Assign the iteration variables their new values before recurring.
+        arguments.eachWithIndex { Expression argument, int index ->
+            ExpressionStatement argAssignment = createAssignmentToIterationVariable(argument, index, positionMapping)
+            argAssignments.add(argAssignment)
+            result.addStatement(argAssignment)
+        }
+
+        // Temps that no reassignment ended up reading are unnecessary: drop their declarations again.
+        Set<String> unusedTemps = replaceAllArgUsages(argAssignments, tempMapping)
+        for (String temp : unusedTemps) {
+            result.statements.remove(tempDeclarations[temp])
+        }
+        result.addStatement(recurStatement)
+
+        return result
+    }
+
+    /**
+     * Builds the statement {@code <iterationVariable> = <expression>} for the argument at {@code index}.
+     */
+    private ExpressionStatement createAssignmentToIterationVariable(Expression expression, int index, Map<Integer, Map> positionMapping) {
+        String argName = positionMapping[index]['name']
+        ClassNode argAndTempType = positionMapping[index]['type'] as ClassNode
+        return (ExpressionStatement) assignS(varX(argName, argAndTempType), expression)
+    }
+
+    /**
+     * Creates (via AstHelper.createVariableAlias) the declaration of temp variable {@code _<argName>_}
+     * for the iteration variable at {@code index}. The temp is recorded in {@code tempMapping}
+     * (keyed by argument name) and its declaration in {@code tempDeclarations} (keyed by temp name).
+     */
+    private ExpressionStatement createTempDeclaration(int index, Map<Integer, Map> positionMapping, Map<String, Map> tempMapping, Map<String, ExpressionStatement> tempDeclarations) {
+        String argName = positionMapping[index]['name']
+        String tempName = "_${argName}_"
+        ClassNode argAndTempType = positionMapping[index]['type'] as ClassNode
+        ExpressionStatement tempDeclaration = AstHelper.createVariableAlias(tempName, argAndTempType, argName)
+        tempMapping[argName] = [name: tempName, type: argAndTempType]
+        tempDeclarations[tempName] = tempDeclaration
+        return tempDeclaration
+    }
+
+    /**
+     * Returns the argument expressions of a method call or static method call.
+     *
+     * @throws IllegalArgumentException if {@code recursiveCall} is neither kind of call
+     */
+    private List<Expression> getArguments(Expression recursiveCall) {
+        if (recursiveCall instanceof MethodCallExpression) {
+            return ((TupleExpression) ((MethodCallExpression) recursiveCall).arguments).expressions
+        }
+        if (recursiveCall instanceof StaticMethodCallExpression) {
+            return ((TupleExpression) ((StaticMethodCallExpression) recursiveCall).arguments).expressions
+        }
+        // Fail fast: callers iterate the result, so an implicit null would only surface later as an NPE.
+        throw new IllegalArgumentException('Not a method call: ' + recursiveCall?.getClass()?.name)
+    }
+
+    /** Exact-class check: subclasses of the two call expression types are not accepted. */
+    private boolean isAMethodCalls(Expression expression) {
+        expression.class in [MethodCallExpression, StaticMethodCallExpression]
+    }
+
+    /**
+     * Replaces usages of the iteration variables inside the given reassignments by their temps.
+     *
+     * @return the names of the temps that none of the reassignments ended up using
+     */
+    private Set<String> replaceAllArgUsages(List<ExpressionStatement> iterationVariablesAssignmentNodes, Map<String, Map> tempMapping) {
+        Set<String> unusedTempNames = tempMapping.values().collect { Map nameAndType -> (String) nameAndType['name'] } as Set<String>
+        // Declared with its concrete type: usedVariableNames is read below and the helper requires it.
+        UsedVariableTracker tracker = new UsedVariableTracker()
+        for (ExpressionStatement statement : iterationVariablesAssignmentNodes) {
+            replaceArgUsageByTempUsage((BinaryExpression) statement.expression, tempMapping, tracker)
+        }
+        return unusedTempNames - tracker.usedVariableNames
+    }
+
+    /**
+     * Runs a VariableAccessReplacer over {@code binary} using {@code tempMapping}, reporting each
+     * replacement to {@code tracker}.
+     */
+    private void replaceArgUsageByTempUsage(BinaryExpression binary, Map tempMapping, UsedVariableTracker tracker) {
+        VariableAccessReplacer replacer = new VariableAccessReplacer(nameAndTypeMapping: tempMapping, listener: tracker)
+        // Replacement must only happen in binary.rightExpression. It's a hack in VariableExpressionReplacer which takes care of that.
+        replacer.replaceIn(binary)
+    }
+}
+
+/**
+ * Listener that records the name of every replacement variable reported to it.
+ */
+@CompileStatic
+class UsedVariableTracker implements VariableReplacedListener {
+
+    /** Names of the variables that were substituted for an original variable access. */
+    final Set<String> usedVariableNames = [] as Set
+
+    @Override
+    void variableReplaced(VariableExpression oldVar, VariableExpression newVar) {
+        usedVariableNames.add(newVar.name)
+    }
+}
http://git-wip-us.apache.org/repos/asf/groovy/blob/d638ca43/src/main/groovy/org/codehaus/groovy/transform/tailrec/StatementReplacer.groovy
----------------------------------------------------------------------
diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/StatementReplacer.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/StatementReplacer.groovy
new file mode 100644
index 0000000..3a9dab3
--- /dev/null
+++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/StatementReplacer.groovy
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.transform.tailrec
+
+import groovy.transform.CompileStatic
+import org.codehaus.groovy.ast.ASTNode
+import org.codehaus.groovy.ast.CodeVisitorSupport
+import org.codehaus.groovy.ast.expr.ClosureExpression
+import org.codehaus.groovy.ast.stmt.BlockStatement
+import org.codehaus.groovy.ast.stmt.DoWhileStatement
+import org.codehaus.groovy.ast.stmt.ForStatement
+import org.codehaus.groovy.ast.stmt.IfStatement
+import org.codehaus.groovy.ast.stmt.Statement
+import org.codehaus.groovy.ast.stmt.WhileStatement
+
+/**
+ * Visitor that swaps Statement nodes in an AST for other Statement instances.
+ * A candidate statement is replaced when {@link #when} accepts it. {@link #replaceWith} produces
+ * the substitute, which inherits the original's source position and node metadata.
+ * Candidates are the direct children of blocks, the branches of if/else statements and the
+ * bodies of for, while and do-while loops.
+ *
+ * Within @TailRecursive it is used to swap ReturnStatements with looping back to the RECUR label.
+ *
+ * @author Johannes Link
+ */
+@CompileStatic
+class StatementReplacer extends CodeVisitorSupport {
+
+    /**
+     * Decides whether a statement is replaced. Takes either the statement alone, or (when it
+     * declares two or more parameters) the statement plus a flag telling whether it lies inside
+     * a closure. Defaults to never replacing anything.
+     */
+    Closure<Boolean> when = { Statement node -> false }
+
+    /** Produces the substitute for an accepted statement. Defaults to the statement itself. */
+    Closure<Statement> replaceWith = { Statement statement -> statement }
+
+    /** Number of closures enclosing the node currently being visited. */
+    int closureLevel = 0
+
+    /** Walks {@code root} and replaces every candidate statement that {@link #when} accepts. */
+    void replaceIn(ASTNode root) {
+        root.visit(this)
+    }
+
+    @Override
+    void visitClosureExpression(ClosureExpression expression) {
+        closureLevel += 1
+        try {
+            super.visitClosureExpression(expression)
+        } finally {
+            closureLevel -= 1
+        }
+    }
+
+    @Override
+    void visitBlockStatement(BlockStatement block) {
+        // Iterate over a snapshot so replacements can be written into block.statements directly.
+        List<Statement> snapshot = new ArrayList<Statement>(block.statements)
+        for (int i = 0; i < snapshot.size(); i++) {
+            final int position = i
+            replaceIfNecessary(snapshot[i]) { Statement replacement -> block.statements[position] = replacement }
+        }
+        super.visitBlockStatement(block)
+    }
+
+    @Override
+    void visitIfElse(IfStatement ifElse) {
+        replaceIfNecessary(ifElse.ifBlock) { Statement s -> ifElse.setIfBlock(s) }
+        replaceIfNecessary(ifElse.elseBlock) { Statement s -> ifElse.setElseBlock(s) }
+        super.visitIfElse(ifElse)
+    }
+
+    @Override
+    void visitForLoop(ForStatement forLoop) {
+        replaceIfNecessary(forLoop.loopBlock) { Statement s -> forLoop.setLoopBlock(s) }
+        super.visitForLoop(forLoop)
+    }
+
+    @Override
+    void visitWhileLoop(WhileStatement loop) {
+        replaceIfNecessary(loop.loopBlock) { Statement s -> loop.setLoopBlock(s) }
+        super.visitWhileLoop(loop)
+    }
+
+    @Override
+    void visitDoWhileLoop(DoWhileStatement loop) {
+        replaceIfNecessary(loop.loopBlock) { Statement s -> loop.setLoopBlock(s) }
+        super.visitDoWhileLoop(loop)
+    }
+
+    /**
+     * If {@link #when} accepts {@code nodeToCheck}, builds its substitute and copies the source
+     * position and node metadata over. Then hands the substitute to {@code replacementCode},
+     * which installs it in the parent node.
+     */
+    private void replaceIfNecessary(Statement nodeToCheck, Closure replacementCode) {
+        if (!conditionFulfilled(nodeToCheck)) {
+            return
+        }
+        ASTNode substitute = replaceWith(nodeToCheck)
+        substitute.setSourcePosition(nodeToCheck)
+        substitute.copyNodeMetaData(nodeToCheck)
+        replacementCode(substitute)
+    }
+
+    /** Evaluates {@link #when}, also passing the closure flag if it declares two or more parameters. */
+    private boolean conditionFulfilled(ASTNode nodeToCheck) {
+        if (when.maximumNumberOfParameters >= 2) {
+            return when(nodeToCheck, isInClosure())
+        }
+        return when(nodeToCheck)
+    }
+
+    /** True while at least one closure encloses the node being visited. */
+    private boolean isInClosure() {
+        closureLevel > 0
+    }
+
+}
