http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/tools/GrapeMain.groovy
----------------------------------------------------------------------
diff --git a/src/main/groovy/org/codehaus/groovy/tools/GrapeMain.groovy b/src/main/groovy/org/codehaus/groovy/tools/GrapeMain.groovy
new file mode 100644
index 0000000..2f53afa
--- /dev/null
+++ b/src/main/groovy/org/codehaus/groovy/tools/GrapeMain.groovy
@@ -0,0 +1,322 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.tools
+
+import groovy.grape.Grape
+import groovy.transform.Field
+import org.apache.commons.cli.CommandLine
+import org.apache.commons.cli.DefaultParser
+import org.apache.commons.cli.HelpFormatter
+import org.apache.commons.cli.Option
+import org.apache.commons.cli.OptionGroup
+import org.apache.commons.cli.Options
+import org.apache.ivy.util.DefaultMessageLogger
+import org.apache.ivy.util.Message
+
+//commands
+
+// install <group> <module> [<version> [<classifier>]]: grabs one grape via Grape.grab (autoDownload).
+// arg[0] is the command name itself; the version defaults to '*' when omitted.
+@Field install = {arg, cmd ->
+    if (arg.size() > 5 || arg.size() < 3) {
+        println 'install requires two to four arguments: <group> <module> [<version> [<classifier>]]'
+        return
+    }
+    def ver = '*'
+    if (arg.size() >= 4) {
+        ver = arg[3]
+    }
+    def classifier = null
+    if (arg.size() >= 5) {
+        classifier = arg[4]
+    }
+
+    // set the instance so we can re-set the logger
+    Grape.getInstance()
+    setupLogging()
+
+    cmd.getOptionValues('r')?.each { String url ->
+        Grape.addResolver(name:url, root:url)
+    }
+
+    try {
+        Grape.grab(autoDownload: true, group: arg[1], module: arg[2], version: ver, classifier: classifier, noExceptions: true)
+    } catch (Exception e) {
+        println "An error occured : $e"
+    }
+}
+
+// uninstall <group> <module> <version>: removes that artifact from the grape cache (non-transitively).
+// When there is no exact match, grapes whose group or module name partially matches are listed instead.
+@Field uninstall = {arg, cmd ->
+    if (arg.size() != 4) {
+        println 'uninstall requires three arguments: <group> <module> <version>'
+        // TODO make version optional? support classifier?
+//        println 'uninstall requires two to four arguments, <group> <module> [<version>] [<classifier>]'
+        return
+    }
+    String group = arg[1]
+    String module = arg[2]
+    String ver = arg[3]
+//    def classifier = null
+
+    // set the instance so we can re-set the logger
+    Grape.getInstance()
+    setupLogging()
+
+    if (!Grape.enumerateGrapes().find {String groupName, Map g ->
+        g.any {String moduleName, List<String> versions ->
+            group == groupName && module == moduleName && ver in versions
+        }
+    }) {
+        println "uninstall did not find grape matching: $group $module $ver"
+        def fuzzyMatches = Grape.enumerateGrapes().findAll { String groupName, Map g ->
+            g.any {String moduleName, List<String> versions ->
+                groupName.contains(group) || moduleName.contains(module) ||
+                group.contains(groupName) || module.contains(moduleName)
+            }
+        }
+        if (fuzzyMatches) {
+            println 'possible matches:'
+            fuzzyMatches.each { String groupName, Map g -> println " $groupName: $g" }
+        }
+        return
+    }
+    Grape.instance.uninstallArtifact(group, module, ver)
+}
+
+// list: prints every cached group/module with its cached versions, followed by the totals.
+@Field list = {arg, cmd ->
+    println ""
+
+    int moduleCount = 0
+    int versionCount = 0
+
+    // set the instance so we can re-set the logger
+    Grape.getInstance()
+    setupLogging()
+
+    Grape.enumerateGrapes().each {String groupName, Map group ->
+        group.each {String moduleName, List<String> versions ->
+            println "$groupName $moduleName $versions"
+            moduleCount++
+            versionCount += versions.size()
+        }
+    }
+    println ""
+    println "$moduleCount Grape modules cached"
+    println "$versionCount Grape module versions cached"
+}
+
+// resolve [-a|-d|-s|-i] (<group> <module> <version>)+: resolves the given grapes and prints the result as
+// plain paths/URIs (default), an Ant path (-a), a DOS (-d) or shell (-s) CLASSPATH, or Ivy dependencies (-i).
+@Field resolve = {arg, cmd ->
+    Options options = new Options();
+    options.addOption(Option.builder("a").hasArg(false).longOpt("ant").build());
+    options.addOption(Option.builder("d").hasArg(false).longOpt("dos").build());
+    options.addOption(Option.builder("s").hasArg(false).longOpt("shell").build());
+    options.addOption(Option.builder("i").hasArg(false).longOpt("ivy").build());
+    CommandLine cmd2 = new DefaultParser().parse(options, arg[1..-1] as String[], true);
+    arg = cmd2.args
+
+    // set the instance so we can re-set the logger
+    Grape.getInstance()
+    setupLogging(Message.MSG_ERR)
+
+    if ((arg.size() % 3) != 0) {
+        println 'There needs to be a multiple of three arguments: (group module version)+'
+        return
+    }
+    // check the references left after option parsing ('arg'), not the script's raw 'args'
+    if (arg.size() < 3) {
+        println 'At least one Grape reference is required'
+        return
+    }
+    def before, between, after
+    def ivyFormatRequested = false
+
+    if (cmd2.hasOption('a')) {
+        before = '<pathelement location="'
+        between = '">\n<pathelement location="'
+        after = '">'
+    } else if (cmd2.hasOption('d')) {
+        before = 'set CLASSPATH='
+        between = ';'
+        after = ''
+    } else if (cmd2.hasOption('s')) {
+        before = 'export CLASSPATH='
+        between = ':'
+        after = ''
+    } else if (cmd2.hasOption('i')) {
+        ivyFormatRequested = true
+        before = '<dependency '
+        between = '">\n<dependency '
+        after = '">'
+    } else {
+        before = ''
+        between = '\n'
+        after = '\n'
+    }
+
+    def iter = arg.iterator()
+    def params = [[:]]
+    def depsInfo = [] // this list will contain the module/group/version info of all resolved dependencies
+    if(ivyFormatRequested) {
+        params << depsInfo
+    }
+    while (iter.hasNext()) {
+        params.add([group: iter.next(), module: iter.next(), version: iter.next()])
+    }
+    try {
+        def results = []
+        def uris = Grape.resolve(* params)
+        if(!ivyFormatRequested) {
+            for (URI uri: uris) {
+                if (uri.scheme == 'file') {
+                    results += new File(uri).path
+                } else {
+                    results += uri.toASCIIString()
+                }
+            }
+        } else {
+            depsInfo.each { dep ->
+                results += ('org="' + dep.group + '" name="' + dep.module + '" revision="' + dep.revision)
+            }
+        }
+
+        if (results) {
+            println "${before}${results.join(between)}${after}"
+        } else {
+            println 'Nothing was resolved'
+        }
+    } catch (Exception e) {
+        println "Error in resolve:\n\t$e.message"
+        if (e.message =~ /unresolved dependency/) println "Perhaps the grape is not installed?"
+    }
+}
+
+// help: prints usage information.
+@Field help = { arg, cmd -> grapeHelp() }
+
+// command name -> [closure: implementation, shortHelp: one-line description printed by grapeHelp]
+@Field commands = [
+    'install': [closure: install,
+        shortHelp: 'Installs a particular grape'],
+    'uninstall': [closure: uninstall,
+        shortHelp: 'Uninstalls a particular grape (non-transitively removes the respective jar file from the grape cache)'],
+    'list': [closure: list,
+        shortHelp: 'Lists all installed grapes'],
+    'resolve': [closure: resolve,
+        shortHelp: 'Enumerates the jars used by a grape'],
+    'help': [closure: help,
+        shortHelp: 'Usage information']
+]
+
+// Prints the global option summary (via commons-cli HelpFormatter) followed by the command list.
+@Field grapeHelp = {
+    int spacesLen = commands.keySet().max {it.length()}.length() + 3
+    String spaces = ' ' * spacesLen
+
+    PrintWriter pw = new PrintWriter(binding.variables.out ?: System.out)
+    new HelpFormatter().printHelp(
+            pw,
+            80,
+            "grape [options] <command> [args]\n",
+            "options:",
+            options,
+            2,
+            4,
+            null, // footer
+            true);
+    pw.flush()
+
+    println ""
+    println "commands:"
+    commands.each {String k, v ->
+        println " ${(k + spaces).substring(0, spacesLen)} $v.shortHelp"
+    }
+    println ""
+}
+
+// Installs an Ivy message logger whose level comes from the global -q/-w/-i/-V/-d flags,
+// falling back to defaultLevel (2, i.e. Message.MSG_INFO) when none of them was given.
+@Field setupLogging = {int defaultLevel = 2 -> // = Message.MSG_INFO -> some parsing error :(
+    if (cmd.hasOption('q')) {
+        Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_ERR))
+    } else if (cmd.hasOption('w')) {
+        Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_WARN))
+    } else if (cmd.hasOption('i')) {
+        Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_INFO))
+    } else if (cmd.hasOption('V')) {
+        Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_VERBOSE))
+    } else if (cmd.hasOption('d')) {
+        Message.setDefaultLogger(new DefaultMessageLogger(Message.MSG_DEBUG))
+    } else {
+        Message.setDefaultLogger(new DefaultMessageLogger(defaultLevel))
+    }
+}
+
+// command line parsing
+@Field Options options = new Options();
+
+options.addOption(Option.builder("D").longOpt("define").desc("define a system property").numberOfArgs(2).valueSeparator().argName("name=value").build());
+options.addOption(Option.builder("r").longOpt("resolver").desc("define a grab resolver (for install)").hasArg(true).argName("url").build());
+options.addOption(Option.builder("h").hasArg(false).desc("usage information").longOpt("help").build());
+
+// Logging Level Options
+options.addOptionGroup(
+        new OptionGroup()
+                .addOption(Option.builder("q").hasArg(false).desc("Log level 0 - only errors").longOpt("quiet").build())
+                .addOption(Option.builder("w").hasArg(false).desc("Log level 1 - errors and warnings").longOpt("warn").build())
+                .addOption(Option.builder("i").hasArg(false).desc("Log level 2 - info").longOpt("info").build())
+                .addOption(Option.builder("V").hasArg(false).desc("Log level 3 - verbose").longOpt("verbose").build())
+                .addOption(Option.builder("d").hasArg(false).desc("Log level 4 - debug").longOpt("debug").build())
+)
+options.addOption(Option.builder("v").hasArg(false).desc("display the Groovy and JVM versions").longOpt("version").build());
+
+@Field CommandLine cmd
+
+cmd = new DefaultParser().parse(options, args, true);
+
+if (cmd.hasOption('h')) {
+    grapeHelp()
+    return
+}
+
+if (cmd.hasOption('v')) {
+    String version = GroovySystem.getVersion();
+    println "Groovy Version: $version JVM: ${System.getProperty('java.version')}"
+    return
+}
+
+// -D values are read from the parsed CommandLine; Options only describes which options are accepted
+if (cmd.hasOption('D')) {
+    cmd.getOptionProperties('D')?.each { k, v ->
+        System.setProperty(k, v)
+    }
+}
+
+String[] arg = cmd.args
+if (arg?.length == 0) {
+    grapeHelp()
+} else if (commands.containsKey(arg[0])) {
+    commands[arg[0]].closure(arg, cmd)
+} else {
+    println "grape: '${arg[0]}' is not a grape command. See 'grape --help'"
+}
http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/tools/ast/TransformTestHelper.groovy
----------------------------------------------------------------------
diff --git a/src/main/groovy/org/codehaus/groovy/tools/ast/TransformTestHelper.groovy b/src/main/groovy/org/codehaus/groovy/tools/ast/TransformTestHelper.groovy
new file mode 100644
index 0000000..ce98d58
--- /dev/null
+++ b/src/main/groovy/org/codehaus/groovy/tools/ast/TransformTestHelper.groovy
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.tools.ast
+
+import org.codehaus.groovy.ast.ClassNode
+import org.codehaus.groovy.classgen.GeneratorContext
+import org.codehaus.groovy.control.CompilationUnit
+import org.codehaus.groovy.control.CompilationUnit.PrimaryClassNodeOperation
+import org.codehaus.groovy.control.CompilePhase
+import org.codehaus.groovy.control.CompilerConfiguration
+import org.codehaus.groovy.control.SourceUnit
+import org.codehaus.groovy.transform.ASTTransformation
+
+import java.security.CodeSource
+
+/**
+* This TestHarness exists so that a global transform can be run without
+* using the Jar services mechanism, which requires building a jar.
+*
+* To use this simply create an instance of TransformTestHelper with
+* an ASTTransformation and CompilePhase, then invoke parse(File) or
+* parse(String).
+*
+* This test harness is not exactly the same as executing a global transformation
+* but can greatly aid in debugging and testing a transform. You should still
+* test your global transformation when packaged as a jar service before
+* releasing it.
+*
+* @author Hamlet D'Arcy
+*/
+class TransformTestHelper {
+
+    private final ASTTransformation transform
+    private final CompilePhase phase
+
+    /**
+     * Creates the test helper.
+     * @param transform
+     *      the transform to run when compiling the file later
+     * @param phase
+     *      the phase to run the transform in
+     */
+    TransformTestHelper(ASTTransformation transform, CompilePhase phase) {
+        this.transform = transform
+        this.phase = phase
+    }
+
+    /**
+     * Compiles the File into a Class applying the transform specified in the constructor.
+     * @param input
+     *      must be a groovy source file
+     * @return the compiled class
+     */
+    public Class parse(File input) {
+        TestHarnessClassLoader loader = new TestHarnessClassLoader(transform, phase)
+        return loader.parseClass(input)
+    }
+
+    /**
+     * Compiles the String into a Class applying the transform specified in the constructor.
+     * @param input
+     *      must be a valid groovy source string
+     * @return the compiled class
+     */
+    public Class parse(String input) {
+        TestHarnessClassLoader loader = new TestHarnessClassLoader(transform, phase)
+        return loader.parseClass(input)
+    }
+}
+
+/**
+* ClassLoader exists so that TestHarnessOperation can be wired into the compile.
+*
+* @author Hamlet D'Arcy
+*/
+@groovy.transform.PackageScope class TestHarnessClassLoader extends GroovyClassLoader {
+
+    private final ASTTransformation transform
+    private final CompilePhase phase
+
+    TestHarnessClassLoader(ASTTransformation transform, CompilePhase phase) {
+        this.transform = transform
+        this.phase = phase
+    }
+
+    @Override
+    protected CompilationUnit createCompilationUnit(CompilerConfiguration config, CodeSource codeSource) {
+        CompilationUnit cu = super.createCompilationUnit(config, codeSource)
+        cu.addPhaseOperation(new TestHarnessOperation(transform), phase.getPhaseNumber())
+        return cu
+    }
+}
+
+/**
+* Operation exists so that an AstTransformation can be run against the SourceUnit.
+*
+* @author Hamlet D'Arcy
+*/
+@groovy.transform.PackageScope class TestHarnessOperation extends PrimaryClassNodeOperation {
+
+    private final ASTTransformation transform
+
+    TestHarnessOperation(transform) {
+        this.transform = transform;
+    }
+
+    @Override
+    public void call(SourceUnit source, GeneratorContext context, ClassNode classNode) {
+        // a global transform is not triggered by an annotation, so no AST nodes are passed - only the SourceUnit
+        transform.visit(null, source)
+    }
+}
http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/ASTTestTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/ASTTestTransformation.groovy b/src/main/groovy/org/codehaus/groovy/transform/ASTTestTransformation.groovy deleted file mode 100644 index dcfe314..0000000 --- a/src/main/groovy/org/codehaus/groovy/transform/ASTTestTransformation.groovy +++ /dev/null @@ -1,233 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.codehaus.groovy.transform - -import groovy.transform.CompilationUnitAware -import org.codehaus.groovy.ast.ASTNode -import org.codehaus.groovy.ast.AnnotationNode -import org.codehaus.groovy.ast.ClassCodeVisitorSupport -import org.codehaus.groovy.ast.ClassHelper -import org.codehaus.groovy.ast.ClassNode -import org.codehaus.groovy.ast.MethodNode -import org.codehaus.groovy.ast.expr.ClosureExpression -import org.codehaus.groovy.ast.expr.PropertyExpression -import org.codehaus.groovy.ast.expr.VariableExpression -import org.codehaus.groovy.ast.stmt.Statement -import org.codehaus.groovy.control.CompilationUnit -import org.codehaus.groovy.control.CompilePhase -import org.codehaus.groovy.control.CompilerConfiguration -import org.codehaus.groovy.control.ErrorCollector -import org.codehaus.groovy.control.Janitor -import org.codehaus.groovy.control.ProcessingUnit -import org.codehaus.groovy.control.SourceUnit -import org.codehaus.groovy.control.customizers.ImportCustomizer -import org.codehaus.groovy.control.io.ReaderSource -import org.codehaus.groovy.runtime.MethodClosure -import org.codehaus.groovy.syntax.SyntaxException -import org.codehaus.groovy.tools.Utilities - -import static org.codehaus.groovy.ast.tools.GeneralUtils.classX -import static org.codehaus.groovy.ast.tools.GeneralUtils.propX - -@GroovyASTTransformation(phase = CompilePhase.SEMANTIC_ANALYSIS) -class ASTTestTransformation extends AbstractASTTransformation implements CompilationUnitAware { - private CompilationUnit compilationUnit - - void visit(final ASTNode[] nodes, final
SourceUnit source) { - AnnotationNode annotationNode = nodes[0] - def member = annotationNode.getMember('phase') - def phase = null - if (member) { - if (member instanceof VariableExpression) { - phase = CompilePhase.valueOf(member.text) - } else if (member instanceof PropertyExpression) { - phase = CompilePhase.valueOf(member.propertyAsString) - } - annotationNode.setMember('phase', propX(classX(ClassHelper.make(CompilePhase)), phase.toString())) - } - member = annotationNode.getMember('value') - if (member && !(member instanceof ClosureExpression)) { - throw new SyntaxException("ASTTest value must be a closure", member.getLineNumber(), member.getColumnNumber()) - } - if (!member && !annotationNode.getNodeMetaData(ASTTestTransformation)) { - throw new SyntaxException("Missing test expression", annotationNode.getLineNumber(), annotationNode.getColumnNumber()) - } - // convert value into node metadata so that the expression doesn't mix up with other AST xforms like type checking - annotationNode.putNodeMetaData(ASTTestTransformation, member) - annotationNode.getMembers().remove('value') - - def pcallback = compilationUnit.progressCallback - def callback = new CompilationUnit.ProgressCallback() { - Binding binding = new Binding([:].withDefault {null}) - - @Override - void call(final ProcessingUnit context, final int phaseRef) { - if (phase==null || phaseRef == phase.phaseNumber) { - ClosureExpression testClosure = nodes[0].getNodeMetaData(ASTTestTransformation) - StringBuilder sb = new StringBuilder() - for (int i = testClosure.lineNumber; i <= testClosure.lastLineNumber; i++) { - sb.append(source.source.getLine(i, new Janitor())).append('\n') - } - def testSource = sb.substring(testClosure.columnNumber + 1, sb.length()) - testSource = testSource.substring(0, testSource.lastIndexOf('}')) - CompilerConfiguration config = new CompilerConfiguration() - def customizer = new ImportCustomizer() - config.addCompilationCustomizers(customizer) - binding['sourceUnit'] = source
- binding['node'] = nodes[1] - binding['lookup'] = new MethodClosure(LabelFinder, "lookup").curry(nodes[1]) - binding['compilationUnit'] = compilationUnit - binding['compilePhase'] = CompilePhase.fromPhaseNumber(phaseRef) - - GroovyShell shell = new GroovyShell(binding, config) - - source.AST.imports.each { - customizer.addImport(it.alias, it.type.name) - } - source.AST.starImports.each { - customizer.addStarImports(it.packageName) - } - source.AST.staticImports.each { - customizer.addStaticImport(it.value.alias, it.value.type.name, it.value.fieldName) - } - source.AST.staticStarImports.each { - customizer.addStaticStars(it.value.className) - } - shell.evaluate(testSource) - } - } - } - - if (pcallback!=null) { - if (pcallback instanceof ProgressCallbackChain) { - pcallback.addCallback(callback) - } else { - pcallback = new ProgressCallbackChain(pcallback, callback) - } - callback = pcallback - } - - compilationUnit.setProgressCallback(callback) - - } - - void setCompilationUnit(final CompilationUnit unit) { - this.compilationUnit = unit - } - - private static class AssertionSourceDelegatingSourceUnit extends SourceUnit { - private final ReaderSource delegate - - AssertionSourceDelegatingSourceUnit(final String name, final ReaderSource source, final CompilerConfiguration flags, final GroovyClassLoader loader, final ErrorCollector er) { - super(name, '', flags, loader, er) - delegate = source - } - - @Override - String getSample(final int line, final int column, final Janitor janitor) { - String sample = null; - String text = delegate.getLine(line, janitor); - - if (text != null) { - if (column > 0) { - String marker = Utilities.repeatString(" ", column - 1) + "^"; - - if (column > 40) { - int start = column - 30 - 1; - int end = (column + 10 > text.length() ?
text.length() : column + 10 - 1); - sample = " " + text.substring(start, end) + Utilities.eol() + " " + - marker.substring(start, marker.length()); - } else { - sample = " " + text + Utilities.eol() + " " + marker; - } - } else { - sample = text; - } - } - - return sample; - - } - - } - - private static class ProgressCallbackChain extends CompilationUnit.ProgressCallback { - - private final List<CompilationUnit.ProgressCallback> chain = new LinkedList<CompilationUnit.ProgressCallback>() - - ProgressCallbackChain(CompilationUnit.ProgressCallback... callbacks) { - if (callbacks!=null) { - callbacks.each { addCallback(it) } - } - } - - public void addCallback(CompilationUnit.ProgressCallback callback) { - chain << callback - } - - @Override - void call(final ProcessingUnit context, final int phase) { - chain*.call(context, phase) - } - } - - public static class LabelFinder extends ClassCodeVisitorSupport { - - public static List<Statement> lookup(MethodNode node, String label) { - LabelFinder finder = new LabelFinder(label, null) - node.code.visit(finder) - - finder.targets - } - - public static List<Statement> lookup(ClassNode node, String label) { - LabelFinder finder = new LabelFinder(label, null) - node.methods*.code*.visit(finder) - node.declaredConstructors*.code*.visit(finder) - - finder.targets - } - - private final String label - private final SourceUnit unit - - private final List<Statement> targets = new LinkedList<Statement>(); - - LabelFinder(final String label, final SourceUnit unit) { - this.label = label - this.unit = unit; - } - - @Override - protected SourceUnit getSourceUnit() { - unit - } - - @Override - protected void visitStatement(final Statement statement) { - super.visitStatement(statement) - if (statement.statementLabel==label) targets << statement - } - - List<Statement> getTargets() { - return Collections.unmodifiableList(targets) - } - } - -}
http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/ConditionalInterruptibleASTTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/ConditionalInterruptibleASTTransformation.groovy b/src/main/groovy/org/codehaus/groovy/transform/ConditionalInterruptibleASTTransformation.groovy deleted file mode 100644 index 2cda121..0000000 --- a/src/main/groovy/org/codehaus/groovy/transform/ConditionalInterruptibleASTTransformation.groovy +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License.
- */ -package org.codehaus.groovy.transform - -import groovy.transform.ConditionalInterrupt -import org.codehaus.groovy.ast.AnnotatedNode -import org.codehaus.groovy.ast.AnnotationNode -import org.codehaus.groovy.ast.ClassHelper -import org.codehaus.groovy.ast.ClassNode -import org.codehaus.groovy.ast.FieldNode -import org.codehaus.groovy.ast.MethodNode -import org.codehaus.groovy.ast.Parameter -import org.codehaus.groovy.ast.PropertyNode -import org.codehaus.groovy.ast.expr.ArgumentListExpression -import org.codehaus.groovy.ast.expr.ClosureExpression -import org.codehaus.groovy.ast.expr.Expression -import org.codehaus.groovy.ast.expr.MethodCallExpression -import org.codehaus.groovy.ast.expr.VariableExpression -import org.codehaus.groovy.ast.tools.ClosureUtils -import org.codehaus.groovy.control.CompilePhase - -/** - * Allows "interrupt-safe" executions of scripts by adding a custom conditional - * check on loops (for, while, do) and first statement of closures. By default, also adds an interrupt check - * statement on the beginning of method calls. - * - * @see groovy.transform.ConditionalInterrupt - * @author Cedric Champeau - * @author Hamlet D'Arcy - * @author Paul King - * @since 1.8.0 - */ -@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION) -public class ConditionalInterruptibleASTTransformation extends AbstractInterruptibleASTTransformation { - - private static final ClassNode MY_TYPE = ClassHelper.make(ConditionalInterrupt) - - private ClosureExpression conditionNode - private String conditionMethod - private MethodCallExpression conditionCallExpression - private ClassNode currentClass - - protected ClassNode type() { - return MY_TYPE - } - - protected void setupTransform(AnnotationNode node) { - super.setupTransform(node) - def member = node.getMember("value") - if (!member || !(member instanceof ClosureExpression)) internalError("Expected closure value for annotation parameter 'value'.
Found $member") - conditionNode = member; - conditionMethod = 'conditionalTransform' + node.hashCode() + '$condition' - conditionCallExpression = new MethodCallExpression(new VariableExpression('this'), conditionMethod, new ArgumentListExpression()) - } - - protected String getErrorMessage() { - 'Execution interrupted. The following condition failed: ' + convertClosureToSource(conditionNode) - } - - void visitClass(ClassNode type) { - currentClass = type - def method = type.addMethod(conditionMethod, ACC_PRIVATE | ACC_SYNTHETIC, ClassHelper.OBJECT_TYPE, Parameter.EMPTY_ARRAY, ClassNode.EMPTY_ARRAY, conditionNode.code) - method.synthetic = true - if (applyToAllMembers) { - super.visitClass(type) - } - } - - protected Expression createCondition() { - conditionCallExpression - } - - @Override - void visitAnnotations(AnnotatedNode node) { - // this transformation does not apply on annotation nodes - // visiting could lead to stack overflows - } - - @Override - void visitField(FieldNode node) { - if (!node.isStatic() && !node.isSynthetic()) { - super.visitField node - } - } - - @Override - void visitProperty(PropertyNode node) { - if (!node.isStatic() && !node.isSynthetic()) { - super.visitProperty node - } - } - - @Override - void visitClosureExpression(ClosureExpression closureExpr) { - if (closureExpr == conditionNode) return // do not visit the closure from the annotation itself - def code = closureExpr.code - closureExpr.code = wrapBlock(code) - super.visitClosureExpression closureExpr - } - - @Override - void visitMethod(MethodNode node) { - if (node.name == conditionMethod && !node.isSynthetic()) return // do not visit the generated method - if (node.name == 'run' && currentClass.isScript() && node.parameters.length == 0) { - // the run() method should not have the statement added, otherwise the script binding won't be set before - // the condition is actually tested - super.visitMethod(node) - } else { - if (checkOnMethodStart && !node.isSynthetic() &&
!node.isStatic() && !node.isAbstract()) { - def code = node.code - node.code = wrapBlock(code); - } - if (!node.isSynthetic() && !node.isStatic()) super.visitMethod(node) - } - } - - /** - * Converts a ClosureExpression into the String source. - * @param expression a closure - * @return the source the closure was created from - */ - private String convertClosureToSource(ClosureExpression expression) { - try { - return ClosureUtils.convertClosureToSource(this.source.source, expression); - } catch(Exception e) { - return e.message - } - } -} http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/ThreadInterruptibleASTTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/ThreadInterruptibleASTTransformation.groovy b/src/main/groovy/org/codehaus/groovy/transform/ThreadInterruptibleASTTransformation.groovy deleted file mode 100644 index a4fb4c3..0000000 --- a/src/main/groovy/org/codehaus/groovy/transform/ThreadInterruptibleASTTransformation.groovy +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License.
- */ -package org.codehaus.groovy.transform - -import groovy.transform.CompileStatic -import groovy.transform.ThreadInterrupt -import org.codehaus.groovy.ast.ClassHelper -import org.codehaus.groovy.ast.ClassNode -import org.codehaus.groovy.ast.MethodNode -import org.codehaus.groovy.ast.Parameter -import org.codehaus.groovy.ast.expr.ArgumentListExpression -import org.codehaus.groovy.ast.expr.ClassExpression -import org.codehaus.groovy.ast.expr.ClosureExpression -import org.codehaus.groovy.ast.expr.Expression -import org.codehaus.groovy.ast.expr.MethodCallExpression -import org.codehaus.groovy.control.CompilePhase - -/** - * Allows "interrupt-safe" executions of scripts by adding Thread.currentThread().isInterrupted() - * checks on loops (for, while, do) and first statement of closures. By default, also adds an interrupt check - * statement on the beginning of method calls. - * - * @see groovy.transform.ThreadInterrupt - * - * @author Cedric Champeau - * @author Hamlet D'Arcy - * - * @since 1.8.0 - */ -@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION) -@CompileStatic -public class ThreadInterruptibleASTTransformation extends AbstractInterruptibleASTTransformation { - - private static final ClassNode MY_TYPE = ClassHelper.make(ThreadInterrupt) - private static final ClassNode THREAD_TYPE = ClassHelper.make(Thread) - private static final MethodNode CURRENTTHREAD_METHOD - private static final MethodNode ISINTERRUPTED_METHOD - - static { - CURRENTTHREAD_METHOD = THREAD_TYPE.getMethod('currentThread', Parameter.EMPTY_ARRAY) - ISINTERRUPTED_METHOD = THREAD_TYPE.getMethod('isInterrupted', Parameter.EMPTY_ARRAY) - } - - protected ClassNode type() { - return MY_TYPE; - } - - protected String getErrorMessage() { - 'Execution interrupted. The current thread has been interrupted.'
- } - - protected Expression createCondition() { - def currentThread = new MethodCallExpression(new ClassExpression(THREAD_TYPE), - 'currentThread', - ArgumentListExpression.EMPTY_ARGUMENTS) - currentThread.methodTarget = CURRENTTHREAD_METHOD - def isInterrupted = new MethodCallExpression( - currentThread, - 'isInterrupted', ArgumentListExpression.EMPTY_ARGUMENTS) - isInterrupted.methodTarget = ISINTERRUPTED_METHOD - [currentThread, isInterrupted]*.implicitThis = false - - isInterrupted - } - - - @Override - public void visitClosureExpression(ClosureExpression closureExpr) { - def code = closureExpr.code - closureExpr.code = wrapBlock(code) - super.visitClosureExpression closureExpr - } - - @Override - public void visitMethod(MethodNode node) { - if (checkOnMethodStart && !node.isSynthetic() && !node.isAbstract()) { - def code = node.code - node.code = wrapBlock(code); - } - super.visitMethod(node) - } -} http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy b/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy deleted file mode 100644 index fbc923b..0000000 --- a/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy +++ /dev/null @@ -1,321 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.codehaus.groovy.transform - -import groovy.transform.TimedInterrupt -import org.codehaus.groovy.ast.ASTNode -import org.codehaus.groovy.ast.AnnotatedNode -import org.codehaus.groovy.ast.AnnotationNode -import org.codehaus.groovy.ast.ClassCodeVisitorSupport -import org.codehaus.groovy.ast.ClassHelper -import org.codehaus.groovy.ast.ClassNode -import org.codehaus.groovy.ast.FieldNode -import org.codehaus.groovy.ast.MethodNode -import org.codehaus.groovy.ast.PropertyNode -import org.codehaus.groovy.ast.expr.ClosureExpression -import org.codehaus.groovy.ast.expr.ConstantExpression -import org.codehaus.groovy.ast.expr.DeclarationExpression -import org.codehaus.groovy.ast.expr.Expression -import org.codehaus.groovy.ast.stmt.BlockStatement -import org.codehaus.groovy.ast.stmt.DoWhileStatement -import org.codehaus.groovy.ast.stmt.ForStatement -import org.codehaus.groovy.ast.stmt.WhileStatement -import org.codehaus.groovy.control.CompilePhase -import org.codehaus.groovy.control.SourceUnit - -import java.util.concurrent.TimeUnit -import java.util.concurrent.TimeoutException - -import static org.codehaus.groovy.ast.ClassHelper.make -import static org.codehaus.groovy.ast.tools.GeneralUtils.args -import static org.codehaus.groovy.ast.tools.GeneralUtils.callX -import static org.codehaus.groovy.ast.tools.GeneralUtils.classX -import static org.codehaus.groovy.ast.tools.GeneralUtils.constX -import static org.codehaus.groovy.ast.tools.GeneralUtils.ctorX -import static org.codehaus.groovy.ast.tools.GeneralUtils.ifS -import static 
org.codehaus.groovy.ast.tools.GeneralUtils.ltX -import static org.codehaus.groovy.ast.tools.GeneralUtils.plusX -import static org.codehaus.groovy.ast.tools.GeneralUtils.propX -import static org.codehaus.groovy.ast.tools.GeneralUtils.throwS -import static org.codehaus.groovy.ast.tools.GeneralUtils.varX - -/** - * Allows "interrupt-safe" executions of scripts by adding timer expiration - * checks on loops (for, while, do) and first statement of closures. By default, - * also adds an interrupt check statement on the beginning of method calls. - * - * @author Cedric Champeau - * @author Hamlet D'Arcy - * @author Paul King - * @see groovy.transform.ThreadInterrupt - * @since 1.8.0 - */ -@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION) -public class TimedInterruptibleASTTransformation extends AbstractASTTransformation { - - private static final ClassNode MY_TYPE = make(TimedInterrupt) - private static final String CHECK_METHOD_START_MEMBER = 'checkOnMethodStart' - private static final String APPLY_TO_ALL_CLASSES = 'applyToAllClasses' - private static final String APPLY_TO_ALL_MEMBERS = 'applyToAllMembers' - private static final String THROWN_EXCEPTION_TYPE = "thrown" - - public void visit(ASTNode[] nodes, SourceUnit source) { - init(nodes, source); - AnnotationNode node = nodes[0] - AnnotatedNode annotatedNode = nodes[1] - if (!MY_TYPE.equals(node.getClassNode())) { - internalError("Transformation called from wrong annotation: $node.classNode.name") - } - - def checkOnMethodStart = getConstantAnnotationParameter(node, CHECK_METHOD_START_MEMBER, Boolean.TYPE, true) - def applyToAllMembers = getConstantAnnotationParameter(node, APPLY_TO_ALL_MEMBERS, Boolean.TYPE, true) - def applyToAllClasses = applyToAllMembers ? 
getConstantAnnotationParameter(node, APPLY_TO_ALL_CLASSES, Boolean.TYPE, true) : false - def maximum = getConstantAnnotationParameter(node, 'value', Long.TYPE, Long.MAX_VALUE) - def thrown = AbstractInterruptibleASTTransformation.getClassAnnotationParameter(node, THROWN_EXCEPTION_TYPE, make(TimeoutException)) - - Expression unit = node.getMember('unit') ?: propX(classX(TimeUnit), "SECONDS") - - // should be limited to the current SourceUnit or propagated to the whole CompilationUnit - // DO NOT inline visitor creation in code below. It has state that must not persist between calls - if (applyToAllClasses) { - // guard every class and method defined in this script - source.getAST()?.classes?.each { ClassNode it -> - def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) - visitor.visitClass(it) - } - } else if (annotatedNode instanceof ClassNode) { - // only guard this particular class - def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) - visitor.visitClass annotatedNode - } else if (!applyToAllMembers && annotatedNode instanceof MethodNode) { - // only guard this particular method (plus initCode for class) - def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) - visitor.visitMethod annotatedNode - visitor.visitClass annotatedNode.declaringClass - } else if (!applyToAllMembers && annotatedNode instanceof FieldNode) { - // only guard this particular field (plus initCode for class) - def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) - visitor.visitField annotatedNode - visitor.visitClass annotatedNode.declaringClass - } else if (!applyToAllMembers && annotatedNode instanceof 
DeclarationExpression) { - // only guard this particular declaration (plus initCode for class) - def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) - visitor.visitDeclarationExpression annotatedNode - visitor.visitClass annotatedNode.declaringClass - } else { - // only guard the script class - source.getAST()?.classes?.each { ClassNode it -> - if (it.isScript()) { - def visitor = new TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, node.hashCode()) - visitor.visitClass(it) - } - } - } - } - - static def getConstantAnnotationParameter(AnnotationNode node, String parameterName, Class type, defaultValue) { - def member = node.getMember(parameterName) - if (member) { - if (member instanceof ConstantExpression) { - // TODO not sure this try offers value - testing Groovy annotation type handing - throw GroovyBugError or remove? - try { - return member.value.asType(type) - } catch (ignore) { - internalError("Expecting boolean value for ${parameterName} annotation parameter. Found $member") - } - } else { - internalError("Expecting boolean value for ${parameterName} annotation parameter. 
Found $member") - } - } - return defaultValue - } - - private static void internalError(String message) { - throw new RuntimeException("Internal error: $message") - } - - private static class TimedInterruptionVisitor extends ClassCodeVisitorSupport { - final private SourceUnit source - final private boolean checkOnMethodStart - final private boolean applyToAllClasses - final private boolean applyToAllMembers - private FieldNode expireTimeField = null - private FieldNode startTimeField = null - private final Expression unit - private final maximum - private final ClassNode thrown - private final String basename - - TimedInterruptionVisitor(source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, unit, thrown, hash) { - this.source = source - this.checkOnMethodStart = checkOnMethodStart - this.applyToAllClasses = applyToAllClasses - this.applyToAllMembers = applyToAllMembers - this.unit = unit - this.maximum = maximum - this.thrown = thrown - this.basename = 'timedInterrupt' + hash - } - - /** - * @return Returns the interruption check statement. - */ - final createInterruptStatement() { - ifS( - - ltX( - propX(varX("this"), basename + '$expireTime'), - callX(make(System), 'nanoTime') - ), - throwS( - ctorX(thrown, - args( - plusX( - plusX( - constX('Execution timed out after ' + maximum + ' '), - callX(callX(unit, 'name'), 'toLowerCase', propX(classX(Locale), 'US')) - ), - plusX( - constX('. Start time: '), - propX(varX("this"), basename + '$startTime') - ) - ) - - ) - ) - ) - ) - } - - /** - * Takes a statement and wraps it into a block statement which first element is the interruption check statement. - * @param statement the statement to be wrapped - * @return a {@link BlockStatement block statement} which first element is for checking interruption, and the - * second one the statement to be wrapped. 
- */ - private wrapBlock(statement) { - def stmt = new BlockStatement(); - stmt.addStatement(createInterruptStatement()); - stmt.addStatement(statement); - stmt - } - - @Override - void visitClass(ClassNode node) { - if (node.getDeclaredField(basename + '$expireTime')) { - return - } - expireTimeField = node.addField(basename + '$expireTime', - ACC_FINAL | ACC_PRIVATE, - ClassHelper.long_TYPE, - plusX( - callX(make(System), 'nanoTime'), - callX( - propX(classX(TimeUnit), 'NANOSECONDS'), - 'convert', - args(constX(maximum, true), unit) - ) - ) - ); - expireTimeField.synthetic = true - startTimeField = node.addField(basename + '$startTime', - ACC_FINAL | ACC_PRIVATE, - make(Date), - ctorX(make(Date)) - ) - startTimeField.synthetic = true - - // force these fields to be initialized first - node.fields.remove(expireTimeField) - node.fields.remove(startTimeField) - node.fields.add(0, startTimeField) - node.fields.add(0, expireTimeField) - if (applyToAllMembers) { - super.visitClass node - } - } - - @Override - void visitClosureExpression(ClosureExpression closureExpr) { - def code = closureExpr.code - if (code instanceof BlockStatement) { - code.statements.add(0, createInterruptStatement()) - } else { - closureExpr.code = wrapBlock(code) - } - super.visitClosureExpression closureExpr - } - - @Override - void visitField(FieldNode node) { - if (!node.isStatic() && !node.isSynthetic()) { - super.visitField node - } - } - - @Override - void visitProperty(PropertyNode node) { - if (!node.isStatic() && !node.isSynthetic()) { - super.visitProperty node - } - } - - /** - * Shortcut method which avoids duplicating code for every type of loop. - * Actually wraps the loopBlock of different types of loop statements. 
- */ - private visitLoop(loopStatement) { - def statement = loopStatement.loopBlock - loopStatement.loopBlock = wrapBlock(statement) - } - - @Override - void visitForLoop(ForStatement forStatement) { - visitLoop(forStatement) - super.visitForLoop(forStatement) - } - - @Override - void visitDoWhileLoop(final DoWhileStatement doWhileStatement) { - visitLoop(doWhileStatement) - super.visitDoWhileLoop(doWhileStatement) - } - - @Override - void visitWhileLoop(final WhileStatement whileStatement) { - visitLoop(whileStatement) - super.visitWhileLoop(whileStatement) - } - - @Override - void visitMethod(MethodNode node) { - if (checkOnMethodStart && !node.isSynthetic() && !node.isStatic() && !node.isAbstract()) { - def code = node.code - node.code = wrapBlock(code); - } - if (!node.isSynthetic() && !node.isStatic()) { - super.visitMethod(node) - } - } - - protected SourceUnit getSourceUnit() { - return source; - } - } -} http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/tailrec/AstHelper.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/AstHelper.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/AstHelper.groovy new file mode 100644 index 0000000..206b0df --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/AstHelper.groovy @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.expr.VariableExpression +import org.codehaus.groovy.ast.stmt.ContinueStatement +import org.codehaus.groovy.ast.stmt.ExpressionStatement +import org.codehaus.groovy.ast.stmt.Statement +import org.codehaus.groovy.ast.stmt.ThrowStatement + +import java.lang.reflect.Modifier + +import static org.codehaus.groovy.ast.tools.GeneralUtils.classX +import static org.codehaus.groovy.ast.tools.GeneralUtils.declS +import static org.codehaus.groovy.ast.tools.GeneralUtils.propX +import static org.codehaus.groovy.ast.tools.GeneralUtils.varX + +/** + * Helping to create a few standard AST constructs + * + * @author Johannes Link + */ +@CompileStatic +class AstHelper { + static ExpressionStatement createVariableDefinition(String variableName, ClassNode variableType, Expression value, boolean variableShouldBeFinal = false ) { + def newVariable = varX(variableName, variableType) + if (variableShouldBeFinal) + newVariable.setModifiers(Modifier.FINAL) + (ExpressionStatement) declS(newVariable, value) + } + + static ExpressionStatement createVariableAlias(String aliasName, ClassNode variableType, String variableName ) { + createVariableDefinition(aliasName, variableType, varX(variableName, variableType)) + } + + static VariableExpression createVariableReference(Map variableSpec) { + varX((String) variableSpec.name, (ClassNode) variableSpec.type) + } + + /** + * 
This statement should make the code jump to surrounding while loop's start label + * Does not work from within Closures + */ + static Statement recurStatement() { + //continue _RECUR_HERE_ + new ContinueStatement(InWhileLoopWrapper.LOOP_LABEL) + } + + /** + * This statement will throw exception which will be caught and redirected to jump to surrounding while loop's start label + * Also works from within Closures but is a tiny bit slower + */ + static Statement recurByThrowStatement() { + // throw InWhileLoopWrapper.LOOP_EXCEPTION + new ThrowStatement(propX(classX(InWhileLoopWrapper), 'LOOP_EXCEPTION')) + } +} http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/tailrec/CollectRecursiveCalls.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/CollectRecursiveCalls.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/CollectRecursiveCalls.groovy new file mode 100644 index 0000000..2c7e6de --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/CollectRecursiveCalls.groovy @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import org.codehaus.groovy.ast.CodeVisitorSupport +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.StaticMethodCallExpression; + +/** + * Collect all recursive calls within method + * + * @author Johannes Link + */ +@CompileStatic +class CollectRecursiveCalls extends CodeVisitorSupport { + MethodNode method + List<Expression> recursiveCalls = [] + + public void visitMethodCallExpression(MethodCallExpression call) { + if (isRecursive(call)) { + recursiveCalls << call + } + super.visitMethodCallExpression(call) + } + + public void visitStaticMethodCallExpression(StaticMethodCallExpression call) { + if (isRecursive(call)) { + recursiveCalls << call + } + super.visitStaticMethodCallExpression(call) + } + + private boolean isRecursive(call) { + new RecursivenessTester().isRecursive(method: method, call: call) + } + + synchronized List<Expression> collect(MethodNode method) { + recursiveCalls.clear() + this.method = method + this.method.code.visit(this) + recursiveCalls + } +} http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/tailrec/HasRecursiveCalls.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/HasRecursiveCalls.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/HasRecursiveCalls.groovy new file mode 100644 index 0000000..79f8e6d --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/HasRecursiveCalls.groovy @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import org.codehaus.groovy.ast.CodeVisitorSupport +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.StaticMethodCallExpression + +/** + * + * Check if there are any recursive calls in a method + * + * @author Johannes Link + */ +@CompileStatic +class HasRecursiveCalls extends CodeVisitorSupport { + MethodNode method + boolean hasRecursiveCalls = false + + public void visitMethodCallExpression(MethodCallExpression call) { + if (isRecursive(call)) { + hasRecursiveCalls = true + } else { + super.visitMethodCallExpression(call) + } + } + + public void visitStaticMethodCallExpression(StaticMethodCallExpression call) { + if (isRecursive(call)) { + hasRecursiveCalls = true + } else { + super.visitStaticMethodCallExpression(call) + } + } + + private boolean isRecursive(call) { + new RecursivenessTester().isRecursive(method: method, call: call) + } + + synchronized boolean test(MethodNode method) { + hasRecursiveCalls = false + this.method = method + this.method.code.visit(this) + hasRecursiveCalls + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/tailrec/InWhileLoopWrapper.groovy 
---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/InWhileLoopWrapper.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/InWhileLoopWrapper.groovy new file mode 100644 index 0000000..981f146 --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/InWhileLoopWrapper.groovy @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.Parameter +import org.codehaus.groovy.ast.VariableScope +import org.codehaus.groovy.ast.expr.BooleanExpression +import org.codehaus.groovy.ast.expr.ConstantExpression +import org.codehaus.groovy.ast.stmt.BlockStatement +import org.codehaus.groovy.ast.stmt.CatchStatement +import org.codehaus.groovy.ast.stmt.ContinueStatement +import org.codehaus.groovy.ast.stmt.EmptyStatement +import org.codehaus.groovy.ast.stmt.Statement +import org.codehaus.groovy.ast.stmt.TryCatchStatement +import org.codehaus.groovy.ast.stmt.WhileStatement + +/** + * Wrap the body of a method in a while loop, nested in a try-catch. + * This is the first step in making a tail recursive method iterative. + * + * There are two ways to invoke the next iteration step: + * 1. "continue _RECURE_HERE_" is used by recursive calls outside of closures + * 2. 
"throw LOOP_EXCEPTION" is used by recursive calls within closures b/c you cannot invoke "continue" from there + * + * @author Johannes Link + */ +@CompileStatic +class InWhileLoopWrapper { + + static final String LOOP_LABEL = '_RECUR_HERE_' + static final GotoRecurHereException LOOP_EXCEPTION = new GotoRecurHereException() + + void wrap(MethodNode method) { + BlockStatement oldBody = method.code as BlockStatement + TryCatchStatement tryCatchStatement = new TryCatchStatement( + oldBody, + EmptyStatement.INSTANCE + ) + tryCatchStatement.addCatch(new CatchStatement( + new Parameter(ClassHelper.make(GotoRecurHereException), 'ignore'), + new ContinueStatement(InWhileLoopWrapper.LOOP_LABEL) + )) + + WhileStatement whileLoop = new WhileStatement( + new BooleanExpression(new ConstantExpression(true)), + new BlockStatement([tryCatchStatement] as List<Statement>, new VariableScope(method.variableScope)) + ) + List<Statement> whileLoopStatements = ((BlockStatement) whileLoop.loopBlock).statements + if (whileLoopStatements.size() > 0) + whileLoopStatements[0].statementLabel = LOOP_LABEL + BlockStatement newBody = new BlockStatement([] as List<Statement>, new VariableScope(method.variableScope)) + newBody.addStatement(whileLoop) + method.code = newBody + } +} + +/** + * Exception will be thrown by recursive calls in closures and caught in while loop to continue to LOOP_LABEL + */ +class GotoRecurHereException extends Throwable { + +} http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/tailrec/RecursivenessTester.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/RecursivenessTester.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/RecursivenessTester.groovy new file mode 100644 index 0000000..7c9545a --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/RecursivenessTester.groovy @@ -0,0 +1,100 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.expr.ConstantExpression +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.StaticMethodCallExpression +import org.codehaus.groovy.ast.expr.VariableExpression + +/** + * + * Test if a method call is recursive if called within a given method node. + * Handles static calls as well. + * + * Currently known simplifications: + * - Does not check for method overloading or overridden methods. + * - Does not check for matching return types; even void and any object type are considered to be compatible. + * - Argument type matching could be more specific in case of static compilation. 
+ * - Method names via a GString are never considered to be recursive + * + * @author Johannes Link + */ +class RecursivenessTester { + public boolean isRecursive(params) { + assert params.method.class == MethodNode + assert params.call.class == MethodCallExpression || StaticMethodCallExpression + + isRecursive(params.method, params.call) + } + + public boolean isRecursive(MethodNode method, MethodCallExpression call) { + if (!isCallToThis(call)) + return false + // Could be a GStringExpression + if (! (call.method instanceof ConstantExpression)) + return false + if (call.method.value != method.name) + return false + methodParamsMatchCallArgs(method, call) + } + + public boolean isRecursive(MethodNode method, StaticMethodCallExpression call) { + if (!method.isStatic()) + return false + if (method.declaringClass != call.ownerType) + return false + if (call.method != method.name) + return false + methodParamsMatchCallArgs(method, call) + } + + private boolean isCallToThis(MethodCallExpression call) { + if (call.objectExpression == null) + return call.isImplicitThis() + if (! (call.objectExpression instanceof VariableExpression)) { + return false + } + return call.objectExpression.isThisExpression() + } + + private boolean methodParamsMatchCallArgs(method, call) { + if (method.parameters.size() != call.arguments.expressions.size()) + return false + def classNodePairs = [method.parameters*.type, call.arguments*.type].transpose() + return classNodePairs.every { ClassNode paramType, ClassNode argType -> + return areTypesCallCompatible(argType, paramType) + } + } + + /** + * Parameter type and calling argument type can both be derived from the other since typing information is + * optional in Groovy. 
+ * Since int is not derived from Integer (nor the other way around) we compare the boxed types + */ + private areTypesCallCompatible(ClassNode argType, ClassNode paramType) { + ClassNode boxedArg = ClassHelper.getWrapper(argType) + ClassNode boxedParam = ClassHelper.getWrapper(paramType) + return boxedArg.isDerivedFrom(boxedParam) || boxedParam.isDerivedFrom(boxedArg) + } + +} http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnAdderForClosures.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnAdderForClosures.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnAdderForClosures.groovy new file mode 100644 index 0000000..64ebce7 --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnAdderForClosures.groovy @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.codehaus.groovy.transform.tailrec + +import org.codehaus.groovy.ast.ClassHelper +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.CodeVisitorSupport +import org.codehaus.groovy.ast.MethodNode +import org.codehaus.groovy.ast.Parameter +import org.codehaus.groovy.ast.expr.ClosureExpression +import org.codehaus.groovy.classgen.ReturnAdder + +/** + * Adds explicit return statements to implicit return points in a closure. This is necessary since + * tail-recursion is detected by having the recursive call within the return statement. + * + * @author Johannes Link + */ +class ReturnAdderForClosures extends CodeVisitorSupport { + + synchronized void visitMethod(MethodNode method) { + method.code.visit(this) + } + + public void visitClosureExpression(ClosureExpression expression) { + //Create a dummy method with the closure's code as the method's code. Then user ReturnAdder, which only works for methods. + MethodNode node = new MethodNode("dummy", 0, ClassHelper.OBJECT_TYPE, Parameter.EMPTY_ARRAY, ClassNode.EMPTY_ARRAY, expression.code); + new ReturnAdder().visitMethod(node); + super.visitClosureExpression(expression) + } + +} http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnStatementToIterationConverter.groovy ---------------------------------------------------------------------- diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnStatementToIterationConverter.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnStatementToIterationConverter.groovy new file mode 100644 index 0000000..2c75f4f --- /dev/null +++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/ReturnStatementToIterationConverter.groovy @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.codehaus.groovy.transform.tailrec + +import groovy.transform.CompileStatic +import org.codehaus.groovy.ast.ClassNode +import org.codehaus.groovy.ast.expr.BinaryExpression +import org.codehaus.groovy.ast.expr.Expression +import org.codehaus.groovy.ast.expr.MethodCallExpression +import org.codehaus.groovy.ast.expr.StaticMethodCallExpression +import org.codehaus.groovy.ast.expr.TupleExpression +import org.codehaus.groovy.ast.expr.VariableExpression +import org.codehaus.groovy.ast.stmt.BlockStatement +import org.codehaus.groovy.ast.stmt.ExpressionStatement +import org.codehaus.groovy.ast.stmt.ReturnStatement +import org.codehaus.groovy.ast.stmt.Statement + +import static org.codehaus.groovy.ast.tools.GeneralUtils.assignS +import static org.codehaus.groovy.ast.tools.GeneralUtils.varX + +/** + * Translates all return statements into an invocation of the next iteration. This can be either + * - "continue LOOP_LABEL": Outside closures + * - "throw LOOP_EXCEPTION": Inside closures + * + * Moreover, before adding the recur statement the iteration parameters (originally the method args) + * are set to their new value. 
To prevent variable aliasing parameters will be copied into temp vars
+ * before they are changed so that their current iteration value can be used when setting other params.
+ *
+ * There's probably room for optimizing the amount of variable copying being done, e.g.
+ * parameters that are only handed through need not be copied at all.
+ *
+ * @author Johannes Link
+ */
+@CompileStatic
+class ReturnStatementToIterationConverter {
+
+    /** Statement appended to every converted return; it triggers the next iteration. */
+    Statement recurStatement = AstHelper.recurStatement()
+
+    /**
+     * Converts a return statement whose expression is a (static) method call into a block that
+     * assigns the call's arguments to the iteration variables and then runs {@link #recurStatement}.
+     * Any other return statement is handed back unchanged.
+     *
+     * @param statement the return statement to convert
+     * @param positionMapping maps each argument position to a map holding the iteration variable's
+     *        'name' (String) and 'type' (ClassNode)
+     * @return the replacement block, or {@code statement} itself if it does not return a method call
+     */
+    Statement convert(ReturnStatement statement, Map<Integer, Map> positionMapping) {
+        Expression recursiveCall = statement.expression
+        if (!isAMethodCalls(recursiveCall))
+            return statement
+
+        Map<String, Map> tempMapping = [:]
+        Map<String, ExpressionStatement> tempDeclarations = [:]
+        List<ExpressionStatement> argAssignments = []
+        List<Expression> arguments = getArguments(recursiveCall)
+
+        BlockStatement result = new BlockStatement()
+        result.statementLabel = statement.statementLabel
+
+        /* Create temp declarations for all method arguments.
+         * Add the declarations and var mapping to tempMapping and tempDeclarations for further reference.
+         */
+        arguments.eachWithIndex { Expression expression, int index ->
+            ExpressionStatement tempDeclaration = createTempDeclaration(index, positionMapping, tempMapping, tempDeclarations)
+            result.addStatement(tempDeclaration)
+        }
+
+        /*
+         * Assign the iteration variables their new value before recurring
+         */
+        arguments.eachWithIndex { Expression expression, int index ->
+            ExpressionStatement argAssignment = createAssignmentToIterationVariable(expression, index, positionMapping)
+            argAssignments.add(argAssignment)
+            result.addStatement(argAssignment)
+        }
+
+        // Drop the declarations of temps that no assignment ended up reading.
+        Set<String> unusedTemps = replaceAllArgUsages(argAssignments, tempMapping)
+        for (String temp : unusedTemps) {
+            result.statements.remove(tempDeclarations[temp])
+        }
+        result.addStatement(recurStatement)
+
+        return result
+    }
+
+    /** Creates the statement {@code <iterationVariable> = <expression>} for the argument at {@code index}. */
+    private ExpressionStatement createAssignmentToIterationVariable(Expression expression, int index, Map<Integer, Map> positionMapping) {
+        String argName = positionMapping[index]['name']
+        ClassNode argAndTempType = positionMapping[index]['type'] as ClassNode
+        ExpressionStatement argAssignment = (ExpressionStatement) assignS(varX(argName, argAndTempType), expression)
+        argAssignment
+    }
+
+    /**
+     * Creates the declaration of the temp variable {@code _<argName>_} aliasing the iteration variable at
+     * {@code index}, and records it in {@code tempMapping} (keyed by argument name, value: temp name and type)
+     * and {@code tempDeclarations} (keyed by temp name).
+     */
+    private ExpressionStatement createTempDeclaration(int index, Map<Integer, Map> positionMapping, Map<String, Map> tempMapping, Map<String, ExpressionStatement> tempDeclarations) {
+        String argName = positionMapping[index]['name']
+        String tempName = "_${argName}_"
+        ClassNode argAndTempType = positionMapping[index]['type'] as ClassNode
+        ExpressionStatement tempDeclaration = AstHelper.createVariableAlias(tempName, argAndTempType, argName)
+        tempMapping[argName] = [name: tempName, type: argAndTempType]
+        tempDeclarations[tempName] = tempDeclaration
+        return tempDeclaration
+    }
+
+    /** Returns the argument expressions of a (static) method call, or an empty list for any other expression. */
+    private List<Expression> getArguments(Expression recursiveCall) {
+        if (recursiveCall instanceof MethodCallExpression)
+            return ((TupleExpression) ((MethodCallExpression) recursiveCall).arguments).expressions
+        if (recursiveCall instanceof StaticMethodCallExpression)
+            return ((TupleExpression) ((StaticMethodCallExpression) recursiveCall).arguments).expressions
+        return []
+    }
+
+    /** True only if the expression's class is exactly MethodCallExpression or StaticMethodCallExpression. */
+    private boolean isAMethodCalls(Expression expression) {
+        expression.class in [MethodCallExpression, StaticMethodCallExpression]
+    }
+
+    /**
+     * Replaces iteration-variable references in the given assignments by references to their temps
+     * and returns the names of the temps that were never substituted in.
+     */
+    private Set<String> replaceAllArgUsages(List<ExpressionStatement> iterationVariablesAssignmentNodes, Map<String, Map> tempMapping) {
+        Set<String> unusedTempNames = tempMapping.values().collect {Map nameAndType -> (String) nameAndType['name']} as Set<String>
+        // Declared with the concrete type: usedVariableNames exists only on UsedVariableTracker,
+        // and replaceArgUsageByTempUsage requires a UsedVariableTracker.
+        UsedVariableTracker tracker = new UsedVariableTracker()
+        for (ExpressionStatement statement : iterationVariablesAssignmentNodes) {
+            replaceArgUsageByTempUsage((BinaryExpression) statement.expression, tempMapping, tracker)
+        }
+        unusedTempNames = unusedTempNames - tracker.usedVariableNames
+        return unusedTempNames
+    }
+
+    private void replaceArgUsageByTempUsage(BinaryExpression binary, Map tempMapping, UsedVariableTracker tracker) {
+        VariableAccessReplacer replacer = new VariableAccessReplacer(nameAndTypeMapping: tempMapping, listener: tracker)
+        // Replacement must only happen in binary.rightExpression. It's a hack in VariableExpressionReplacer which takes care of that.
+        replacer.replaceIn(binary)
+    }
+}
+
+/** Records the name of every replacement variable reported through {@link VariableReplacedListener}. */
+@CompileStatic
+class UsedVariableTracker implements VariableReplacedListener {
+
+    final Set<String> usedVariableNames = [] as Set
+
+    @Override
+    void variableReplaced(VariableExpression oldVar, VariableExpression newVar) {
+        usedVariableNames.add(newVar.name)
+    }
+}
http://git-wip-us.apache.org/repos/asf/groovy/blob/10110145/src/main/groovy/org/codehaus/groovy/transform/tailrec/StatementReplacer.groovy
----------------------------------------------------------------------
diff --git a/src/main/groovy/org/codehaus/groovy/transform/tailrec/StatementReplacer.groovy b/src/main/groovy/org/codehaus/groovy/transform/tailrec/StatementReplacer.groovy
new file mode 100644
index 0000000..3a9dab3
--- /dev/null
+++ b/src/main/groovy/org/codehaus/groovy/transform/tailrec/StatementReplacer.groovy
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.codehaus.groovy.transform.tailrec
+
+import groovy.transform.CompileStatic
+import org.codehaus.groovy.ast.ASTNode
+import org.codehaus.groovy.ast.CodeVisitorSupport
+import org.codehaus.groovy.ast.expr.ClosureExpression
+import org.codehaus.groovy.ast.stmt.BlockStatement
+import org.codehaus.groovy.ast.stmt.DoWhileStatement
+import org.codehaus.groovy.ast.stmt.ForStatement
+import org.codehaus.groovy.ast.stmt.IfStatement
+import org.codehaus.groovy.ast.stmt.Statement
+import org.codehaus.groovy.ast.stmt.WhileStatement
+
+/**
+ * AST visitor that swaps selected Statement nodes for other Statement instances.
+ *
+ * The statements to swap are chosen by {@link #when}; their substitutes come from {@link #replaceWith}.
+ * Candidates are the direct children of blocks, the branches of if/else statements and the bodies
+ * of for, while and do-while loops.
+ *
+ * Within @TailRecursive it is used to exchange ReturnStatements for looping back to the RECUR label.
+ *
+ * @author Johannes Link
+ */
+@CompileStatic
+class StatementReplacer extends CodeVisitorSupport {
+
+    /**
+     * Selects the statements to replace. A closure with fewer than two parameters receives the statement;
+     * one declaring two or more also receives whether the statement lies inside a closure expression.
+     * Selects nothing by default.
+     */
+    Closure<Boolean> when = { Statement node -> false }
+
+    /** Builds the substitute for a selected statement; by default the statement itself. */
+    Closure<Statement> replaceWith = { Statement statement -> statement }
+
+    /** How many closure expressions enclose the node currently being visited. */
+    int closureLevel = 0
+
+    /** Traverses {@code root} with this visitor, replacing matching statements beneath it. */
+    void replaceIn(ASTNode root) {
+        root.visit(this)
+    }
+
+    @Override
+    public void visitClosureExpression(ClosureExpression expression) {
+        closureLevel += 1
+        try {
+            super.visitClosureExpression(expression)
+        } finally {
+            // keep the depth balanced even when the traversal throws
+            closureLevel -= 1
+        }
+    }
+
+    @Override
+    public void visitBlockStatement(BlockStatement block) {
+        // Walk a snapshot so that writing substitutes into the live list cannot disturb the iteration.
+        List<Statement> snapshot = new ArrayList<Statement>(block.getStatements())
+        for (int position = 0; position < snapshot.size(); position++) {
+            final int slot = position
+            replaceIfNecessary(snapshot.get(position)) { Statement substitute -> block.getStatements().set(slot, substitute) }
+        }
+        super.visitBlockStatement(block)
+    }
+
+    @Override
+    public void visitIfElse(IfStatement ifElse) {
+        replaceIfNecessary(ifElse.getIfBlock()) { Statement substitute -> ifElse.setIfBlock(substitute) }
+        replaceIfNecessary(ifElse.getElseBlock()) { Statement substitute -> ifElse.setElseBlock(substitute) }
+        super.visitIfElse(ifElse)
+    }
+
+    @Override
+    public void visitForLoop(ForStatement forLoop) {
+        replaceIfNecessary(forLoop.getLoopBlock()) { Statement substitute -> forLoop.setLoopBlock(substitute) }
+        super.visitForLoop(forLoop)
+    }
+
+    @Override
+    public void visitWhileLoop(WhileStatement loop) {
+        replaceIfNecessary(loop.getLoopBlock()) { Statement substitute -> loop.setLoopBlock(substitute) }
+        super.visitWhileLoop(loop)
+    }
+
+    @Override
+    public void visitDoWhileLoop(DoWhileStatement loop) {
+        replaceIfNecessary(loop.getLoopBlock()) { Statement substitute -> loop.setLoopBlock(substitute) }
+        super.visitDoWhileLoop(loop)
+    }
+
+    /**
+     * When {@code candidate} is selected by {@link #when}, builds its substitute with {@link #replaceWith},
+     * copies the candidate's source position and node metadata onto it, and passes it to {@code install}.
+     */
+    private void replaceIfNecessary(Statement candidate, Closure install) {
+        if (!conditionFulfilled(candidate)) {
+            return
+        }
+        ASTNode substitute = replaceWith(candidate)
+        substitute.setSourcePosition(candidate)
+        substitute.copyNodeMetaData(candidate)
+        install(substitute)
+    }
+
+    /** Evaluates {@link #when}, adding the in-closure flag for predicates that take two or more parameters. */
+    private boolean conditionFulfilled(ASTNode nodeToCheck) {
+        boolean wantsClosureFlag = when.maximumNumberOfParameters >= 2
+        return wantsClosureFlag ? when(nodeToCheck, isInClosure()) : when(nodeToCheck)
+    }
+
+    /** True while the traversal is inside at least one closure expression. */
+    private boolean isInClosure() {
+        return closureLevel > 0
+    }
+
+}
