Repository: systemml Updated Branches: refs/heads/master 11c67055a -> ef1945d70
[SYSTEMML-540] Allow user to generate an inlined DML script in Caffe2DML - The inlining code is generic enough to be extended to perform parser-level inlining. This commit allows us to compare the tradeoffs of performing script-level inlining v/s hop-level inlining. - Refactored DMLParserWrapper and also added javadoc. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ef1945d7 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ef1945d7 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ef1945d7 Branch: refs/heads/master Commit: ef1945d70a85df4f646c315d06a1a094dad6ebb2 Parents: 11c6705 Author: Niketan Pansare <npan...@us.ibm.com> Authored: Thu Oct 11 14:28:59 2018 -0700 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Thu Oct 11 14:32:35 2018 -0700 ---------------------------------------------------------------------- .../parser/common/CustomErrorListener.java | 8 + .../sysml/parser/dml/DMLParserWrapper.java | 130 ++- .../java/org/apache/sysml/parser/dml/Dml.g4 | 6 +- .../sysml/parser/dml/DmlPreprocessor.java | 13 +- .../apache/sysml/parser/dml/InlineHelper.java | 798 +++++++++++++++++++ .../sysml/parser/dml/InlineableMethods.java | 98 +++ .../controlprogram/caching/CacheableData.java | 9 +- .../gpu/context/GPUMemoryManager.java | 2 +- src/main/python/systemml/mllearn/estimators.py | 5 +- .../org/apache/sysml/api/dl/Caffe2DML.scala | 8 +- .../org/apache/sysml/api/dl/CaffeLayer.scala | 12 +- .../org/apache/sysml/api/dl/CaffeSolver.scala | 35 +- .../org/apache/sysml/api/dl/DMLGenerator.scala | 36 +- .../scala/org/apache/sysml/api/dl/Utils.scala | 58 +- 14 files changed, 1114 insertions(+), 104 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java b/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java index 2af5f69..b82afc9 100644 --- a/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java +++ b/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java @@ -22,6 +22,7 @@ package org.apache.sysml.parser.common; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Set; import org.antlr.v4.runtime.BaseErrorListener; import org.antlr.v4.runtime.RecognitionException; @@ -38,6 +39,9 @@ public class CustomErrorListener extends BaseErrorListener { private boolean atLeastOneError = false; private boolean atLeastOneWarning = false; private String currentFileName = null; + + // Names of user internal and external functions definitions + public Set<String> functions; /** * List of parse issues. @@ -55,6 +59,10 @@ public class CustomErrorListener extends BaseErrorListener { public void unsetCurrentFileName() { currentFileName = null; } + + public Set<String> getFunctionDefs() { + return functions; + } /** * Validation error occurred. Add the error to the list of parse issues. 
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java index 1d7daa1..9b3f8c4 100644 --- a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java +++ b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java @@ -23,12 +23,14 @@ import java.io.ByteArrayInputStream; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; +import java.util.HashMap; import java.util.Map; import org.antlr.v4.runtime.ANTLRInputStream; import org.antlr.v4.runtime.BailErrorStrategy; import org.antlr.v4.runtime.CommonTokenStream; import org.antlr.v4.runtime.DefaultErrorStrategy; +import org.antlr.v4.runtime.TokenStreamRewriter; import org.antlr.v4.runtime.atn.PredictionMode; import org.antlr.v4.runtime.misc.ParseCancellationException; import org.antlr.v4.runtime.tree.ParseTree; @@ -74,6 +76,15 @@ import org.apache.sysml.parser.dml.DmlParser.StatementContext; public class DMLParserWrapper extends ParserWrapper { private static final Log LOG = LogFactory.getLog(DMLScript.class.getName()); + + // Rewriter is only used in getInlineableMethods + private TokenStreamRewriter rewriter = null; + + // The below fields are set in the createAST method + // Can be null or the path to the DML file + private String fileName; + // Can be null or the DML script. Note: both fileName and DML script should not be null + private String dmlScript; /** * Parses the passed file with command line parameters. 
You can either pass both (local file) or just dmlScript (hdfs) or just file name (import command) @@ -88,17 +99,72 @@ public class DMLParserWrapper extends ParserWrapper } /** - * This function is supposed to be called directly only from DmlSyntacticValidator when it encounters 'import' - * @param fileName script file name - * @param dmlScript script file contents - * @param sourceNamespace namespace from source statement - * @param argVals script arguments - * @return dml program, or null if at least one error + * Performs preprocess using DmlPreprocessor listener class. + * + * @param tree parse tree generated by createAST method + * @param errorListener listener that captures potential syntactic errors + * @return a parse tree walker to perform further validation */ - public DMLProgram doParse(String fileName, String dmlScript, String sourceNamespace, Map<String,String> argVals) { - DMLProgram dmlPgm = null; + ParseTreeWalker preprocess(ParseTree tree, CustomErrorListener errorListener) { + ParseTreeWalker walker = new ParseTreeWalker(); + // Get list of function definitions which take precedence over built-in functions if same name + walker.walk(new DmlPreprocessor(errorListener), tree); + return walker; + } + + /** + * Get the inline-able methods + * + * @param fileName1 can be null or the path to the DML file + * @param dmlScript1 can be null or the DML script. Note, both fileName and DML script should not be null. 
+ * @param sourceNamespace source namespace + * @param argVals command-line arguments + * @return hashmap of inline-able methods + */ + public HashMap<String, InlineableMethods> getInlineableMethods(String fileName1, String dmlScript1, String sourceNamespace, Map<String,String> argVals) { + // Create AST and do preprocessing + CustomErrorListener errorListener = new CustomErrorListener(); + ParseTree tree = createAST(fileName1, dmlScript1, sourceNamespace, argVals, errorListener, true); + ParseTreeWalker walker = preprocess(tree, errorListener); + + // Note: this method uses InlineHelper as a listener to perform rewriting of local variables + // It does so in two phases: + // Phase 1. Rewriting phase where local variables are rewritten by adding a prefix. + // Phase 2. Capture the body of the functions using InlineableMethods class + + // Rewrite all the local variables by adding prefix + InlineHelper validator = new InlineHelper(errorListener, argVals, sourceNamespace, errorListener.getFunctionDefs(), rewriter); + validator.setPhase(true); + walker.walk(validator, tree); - ANTLRInputStream in; + // Use the rewritten text as the new DML script and create AST again + fileName = null; dmlScript = rewriter.getText(); + errorListener = new CustomErrorListener(); + tree = createAST(fileName, dmlScript, sourceNamespace, argVals, errorListener, true); + walker = preprocess(tree, errorListener); + + // Put the content of rewritten function body in the inlineMap + validator.setPhase(false); + walker.walk(validator, tree); + + return validator.inlineMap; + } + + /** + * Create an ANTLR parse tree for the input DML script + * + * @param fileName1 can be null or the path to the DML file + * @param dmlScript1 can be null or the DML script. Note, both fileName and DML script should not be null. 
+ * @param sourceNamespace source namespace + * @param argVals command-line arguments + * @param errorListener listener that captures potential syntactic errors + * @param performRewriting should perform rewriting of tokens + * @return a parse tree + */ + private ParseTree createAST(String fileName1, String dmlScript1, String sourceNamespace, Map<String,String> argVals, CustomErrorListener errorListener, boolean performRewriting) { + ANTLRInputStream in = null; + this.fileName = fileName1; + this.dmlScript = dmlScript1; try { if(dmlScript == null) { dmlScript = readDMLScript(fileName, LOG); @@ -113,13 +179,13 @@ public class DMLParserWrapper extends ParserWrapper } catch (LanguageException e) { throw new ParseException(e.getMessage(), e); } - - ProgramrootContext ast = null; - CustomErrorListener errorListener = new CustomErrorListener(); + ProgramrootContext ast = null; try { DmlLexer lexer = new DmlLexer(in); CommonTokenStream tokens = new CommonTokenStream(lexer); + if(performRewriting) + rewriter = new TokenStreamRewriter(tokens); DmlParser antlr4Parser = new DmlParser(tokens); boolean tryOptimizedParsing = false; // For now no optimization, since it is not able to parse integer value. 
@@ -163,19 +229,31 @@ public class DMLParserWrapper extends ParserWrapper catch(Exception e) { throw new ParseException("ERROR: Cannot parse the program:" + fileName, e); } + return ast; + } + + + + /** + * This function is supposed to be called directly only from DmlSyntacticValidator when it encounters 'import' + * + * @param fileName1 script file name + * @param dmlScript1 script file contents + * @param sourceNamespace namespace from source statement + * @param argVals script arguments + * @return dml program, or null if at least one error + */ + public DMLProgram doParse(String fileName1, String dmlScript1, String sourceNamespace, Map<String,String> argVals) { + // Create AST and do preprocessing + CustomErrorListener errorListener = new CustomErrorListener(); + ParseTree tree = createAST(fileName1, dmlScript1, sourceNamespace, argVals, errorListener, false); + ParseTreeWalker walker = preprocess(tree, errorListener); - - // Now convert the parse tree into DMLProgram - // Do syntactic validation while converting - ParseTree tree = ast; - // And also do syntactic validation - ParseTreeWalker walker = new ParseTreeWalker(); - // Get list of function definitions which take precedence over built-in functions if same name - DmlPreprocessor prep = new DmlPreprocessor(errorListener); - walker.walk(prep, tree); - // Syntactic validation - DmlSyntacticValidator validator = new DmlSyntacticValidator(errorListener, argVals, sourceNamespace, prep.getFunctionDefs()); + // Perform syntactic validation using DmlSyntacticValidator listener + DmlSyntacticValidator validator = new DmlSyntacticValidator(errorListener, argVals, sourceNamespace, errorListener.getFunctionDefs()); walker.walk(validator, tree); + + // Check for parse issues and warning errorListener.unsetCurrentFileName(); this.parseIssues = errorListener.getParseIssues(); this.atLeastOneWarning = errorListener.isAtLeastOneWarning(); @@ -186,9 +264,9 @@ public class DMLParserWrapper extends ParserWrapper if 
(atLeastOneWarning) { LOG.warn(CustomErrorListener.generateParseIssuesMessage(dmlScript, parseIssues)); } - dmlPgm = createDMLProgram(ast, sourceNamespace); - return dmlPgm; + // Create and return the DML program + return createDMLProgram((ProgramrootContext)tree, sourceNamespace); } private static DMLProgram createDMLProgram(ProgramrootContext ast, String sourceNamespace) { @@ -255,4 +333,4 @@ public class DMLParserWrapper extends ParserWrapper return dmlPgm; } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/Dml.g4 ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 b/src/main/java/org/apache/sysml/parser/dml/Dml.g4 index 46ee178..08e1288 100644 --- a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 +++ b/src/main/java/org/apache/sysml/parser/dml/Dml.g4 @@ -216,6 +216,6 @@ COMMANDLINE_POSITION_ID: '$' DIGIT+; STRING: '"' ( ESC | ~[\\"] )*? '"' | '\'' ( ESC | ~[\\'] )*? '\''; fragment ESC : '\\' [btnfr"'\\] ; // Comments, whitespaces and new line -LINE_COMMENT : '#' .*? '\r'? '\n' -> skip ; -MULTILINE_BLOCK_COMMENT : '/*' .*? '*/' -> skip ; -WHITESPACE : (' ' | '\t' | '\r' | '\n')+ -> skip ; +LINE_COMMENT : '#' .*? '\r'? '\n' -> channel(HIDDEN) ; +MULTILINE_BLOCK_COMMENT : '/*' .*? 
'*/' -> channel(HIDDEN) ; +WHITESPACE : (' ' | '\t' | '\r' | '\n')+ -> channel(HIDDEN) ; http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java index 56eb8ca..49e96e7 100644 --- a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java +++ b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java @@ -20,7 +20,6 @@ package org.apache.sysml.parser.dml; import java.util.HashSet; -import java.util.Set; import org.antlr.v4.runtime.ParserRuleContext; import org.antlr.v4.runtime.Token; @@ -84,16 +83,10 @@ import org.apache.sysml.parser.dml.DmlParser.WhileStatementContext; public class DmlPreprocessor implements DmlListener { protected final CustomErrorListener errorListener; - // Names of user internal and external functions definitions - protected Set<String> functions; public DmlPreprocessor(CustomErrorListener errorListener) { this.errorListener = errorListener; - functions = new HashSet<>(); - } - - public Set<String> getFunctionDefs() { - return functions; + this.errorListener.functions = new HashSet<>(); } @Override @@ -113,8 +106,8 @@ public class DmlPreprocessor implements DmlListener { public void exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {} protected void validateFunctionName(String name, ParserRuleContext ctx) { - if (!functions.contains(name)) { - functions.add(name); + if (!errorListener.functions.contains(name)) { + errorListener.functions.add(name); } else { notifyErrorListeners("Function Name Conflict: '" + name + "' already defined in " + errorListener.getCurrentFileName(), ctx.start); http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java b/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java new file mode 100644 index 0000000..34d886c --- /dev/null +++ b/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java @@ -0,0 +1,798 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sysml.parser.dml; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Random; +import java.util.Set; + +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.Token; +import org.antlr.v4.runtime.TokenStreamRewriter; +import org.antlr.v4.runtime.tree.ErrorNode; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.apache.sysml.parser.Expression; +import org.apache.sysml.parser.ParameterExpression; +import org.apache.sysml.parser.common.CommonSyntacticValidator; +import org.apache.sysml.parser.common.CustomErrorListener; +import org.apache.sysml.parser.dml.DmlParser.AccumulatorAssignmentStatementContext; +import org.apache.sysml.parser.dml.DmlParser.AddSubExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.AssignmentStatementContext; +import org.apache.sysml.parser.dml.DmlParser.AtomicExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.BooleanAndExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.BooleanNotExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.BooleanOrExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.BuiltinFunctionExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.CommandlineParamExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.CommandlinePositionExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ConstDoubleIdExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ConstFalseExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ConstIntIdExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ConstStringIdExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ConstTrueExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.DataIdExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ExternalFunctionDefExpressionContext; +import 
org.apache.sysml.parser.dml.DmlParser.ForStatementContext; +import org.apache.sysml.parser.dml.DmlParser.FunctionCallAssignmentStatementContext; +import org.apache.sysml.parser.dml.DmlParser.FunctionCallMultiAssignmentStatementContext; +import org.apache.sysml.parser.dml.DmlParser.IfStatementContext; +import org.apache.sysml.parser.dml.DmlParser.IfdefAssignmentStatementContext; +import org.apache.sysml.parser.dml.DmlParser.ImportStatementContext; +import org.apache.sysml.parser.dml.DmlParser.IndexedExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.InternalFunctionDefExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.IterablePredicateColonExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.IterablePredicateSeqExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.MatrixDataTypeCheckContext; +import org.apache.sysml.parser.dml.DmlParser.MatrixMulExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.Ml_typeContext; +import org.apache.sysml.parser.dml.DmlParser.ModIntDivExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.MultDivExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.MultiIdExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ParForStatementContext; +import org.apache.sysml.parser.dml.DmlParser.ParameterizedExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.PathStatementContext; +import org.apache.sysml.parser.dml.DmlParser.PowerExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ProgramrootContext; +import org.apache.sysml.parser.dml.DmlParser.RelationalExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.SimpleDataIdentifierExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.StatementContext; +import org.apache.sysml.parser.dml.DmlParser.StrictParameterizedExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.StrictParameterizedKeyValueStringContext; +import 
org.apache.sysml.parser.dml.DmlParser.TypedArgAssignContext; +import org.apache.sysml.parser.dml.DmlParser.TypedArgNoAssignContext; +import org.apache.sysml.parser.dml.DmlParser.UnaryExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.ValueTypeContext; +import org.apache.sysml.parser.dml.DmlParser.WhileStatementContext; + +/** + * This class is used to generate inline-able methods. + * It does so in two phases: + * - Phase 1. Rewriting phase where local variables are rewritten by adding a prefix. + * - Phase 2. Capture the body of the functions using InlineableMethods class + */ +public class InlineHelper extends CommonSyntacticValidator implements DmlListener { + final static String ARG_PREFIX; + static { + Random rand = new Random(); + ARG_PREFIX = "INTERNAL_PREFIX_" + Math.abs(rand.nextLong()) + "_" + Math.abs(rand.nextLong()) + "_"; + } + public HashMap<String, InlineableMethods> inlineMap = new HashMap<>(); + TokenStreamRewriter rewriter; + + // Set internally + HashSet<String> variables = new HashSet<>(); + String currentFunction = null; + boolean isRewritePhase; + + public InlineHelper(CustomErrorListener errorListener, Map<String, String> argVals, String sourceNamespace, + Set<String> prepFunctions, TokenStreamRewriter rewriter1) { + super(errorListener, argVals, sourceNamespace, prepFunctions); + rewriter = rewriter1; + } + + void setPhase(boolean isRewritePhase1) { + isRewritePhase = isRewritePhase1; + } + + + @Override + public void enterInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) { + currentFunction = ctx.name.getText(); + variables.clear(); + } + + @Override + public void exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) { + if(!isRewritePhase) { + StringBuilder sb = new StringBuilder(); + for(StatementContext stmt : ctx.body) { + sb.append(stmt.getText()); + sb.append("\n"); + } + ArrayList<String> inputArgs = new ArrayList<>(); + for(TypedArgAssignContext in : ctx.inputParams) { + 
inputArgs.add(ARG_PREFIX + in.paramName.getText()); + } + ArrayList<String> retVariables = new ArrayList<>(); + for(TypedArgNoAssignContext out : ctx.outputParams) { + retVariables.add(ARG_PREFIX + out.paramName.getText()); + } + + inlineMap.put(currentFunction, new InlineableMethods(currentFunction, sb.toString(), variables, inputArgs, retVariables)); + } + currentFunction = null; + variables.clear(); + } + + @Override + public void enterIndexedExpression(IndexedExpressionContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertBefore(ctx.name, " " + ARG_PREFIX); + } + } + + @Override + public void exitIndexedExpression(IndexedExpressionContext ctx) { + if(currentFunction != null) + variables.add(ctx.name.getText()); + } + + @Override + public void enterSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertBefore(ctx.start, " " + ARG_PREFIX); + rewriter.insertAfter(ctx.stop, " "); + } + } + + @Override + public void exitSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) { + if(currentFunction != null) + variables.add(ctx.getText()); + } + + @Override + public void enterForStatement(ForStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertBefore(ctx.iterVar, " " + ARG_PREFIX); + rewriter.insertAfter(ctx.iterVar, " "); + } + } + + @Override + public void enterParForStatement(ParForStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertBefore(ctx.iterVar, " " + ARG_PREFIX); + rewriter.insertAfter(ctx.iterVar, " "); + } + } + + @Override + public void exitForStatement(ForStatementContext ctx) { + if(currentFunction != null) + variables.add(ctx.iterVar.getText()); + if(currentFunction != null && isRewritePhase) { + if(ctx.body != null && ctx.body.size() > 0) + rewriter.insertBefore(ctx.body.get(0).start, "\n"); + rewriter.insertAfter(ctx.stop, "\n"); + } + } + + 
@Override + public void exitParForStatement(ParForStatementContext ctx) { + if(currentFunction != null) + variables.add(ctx.iterVar.getText()); + if(currentFunction != null && isRewritePhase) { + if(ctx.body != null && ctx.body.size() > 0) + rewriter.insertBefore(ctx.body.get(0).start, "\n"); + rewriter.insertAfter(ctx.stop, "\n"); + } + } + + + @Override + protected ConvertedDMLSyntax convertToDMLSyntax(ParserRuleContext ctx, String namespace, String functionName, + ArrayList<ParameterExpression> paramExpression, Token fnName) { + + return null; + } + + @Override + public void enterAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) { + + + } + + @Override + public void enterAddSubExpression(AddSubExpressionContext ctx) { + + + } + + @Override + public void enterAssignmentStatement(AssignmentStatementContext ctx) { + + + } + + @Override + public void enterAtomicExpression(AtomicExpressionContext ctx) { + + + } + + @Override + public void enterBooleanAndExpression(BooleanAndExpressionContext ctx) { + + + } + + @Override + public void enterBooleanNotExpression(BooleanNotExpressionContext ctx) { + + + } + + @Override + public void enterBooleanOrExpression(BooleanOrExpressionContext ctx) { + + + } + + @Override + public void enterBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) { + + + } + + @Override + public void enterCommandlineParamExpression(CommandlineParamExpressionContext ctx) { + + + } + + @Override + public void enterCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) { + + + } + + @Override + public void enterConstDoubleIdExpression(ConstDoubleIdExpressionContext ctx) { + + + } + + @Override + public void enterConstFalseExpression(ConstFalseExpressionContext ctx) { + + + } + + @Override + public void enterConstIntIdExpression(ConstIntIdExpressionContext ctx) { + + + } + + @Override + public void enterConstStringIdExpression(ConstStringIdExpressionContext ctx) { + + + } + + @Override + public void 
enterConstTrueExpression(ConstTrueExpressionContext ctx) { + + + } + + @Override + public void enterDataIdExpression(DataIdExpressionContext ctx) { + + + } + + @Override + public void enterEveryRule(ParserRuleContext arg0) { + + + } + + @Override + public void enterExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) { + + + } + + + @Override + public void enterFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) { + + + } + + @Override + public void enterFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) { + + + } + + @Override + public void enterIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) { + + + } + + @Override + public void enterIfStatement(IfStatementContext ctx) { + + + } + + @Override + public void enterImportStatement(ImportStatementContext ctx) { + + + } + + @Override + public void enterIterablePredicateColonExpression(IterablePredicateColonExpressionContext ctx) { + + + } + + @Override + public void enterIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) { + + + } + + @Override + public void enterMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) { + + + } + + @Override + public void enterMatrixMulExpression(MatrixMulExpressionContext ctx) { + + + } + + @Override + public void enterMl_type(Ml_typeContext ctx) { + + + } + + @Override + public void enterModIntDivExpression(ModIntDivExpressionContext ctx) { + + + } + + @Override + public void enterMultDivExpression(MultDivExpressionContext ctx) { + + + } + + @Override + public void enterMultiIdExpression(MultiIdExpressionContext ctx) { + + + } + + @Override + public void enterParameterizedExpression(ParameterizedExpressionContext ctx) { + + + } + + @Override + public void enterPathStatement(PathStatementContext ctx) { + + + } + + @Override + public void enterPowerExpression(PowerExpressionContext ctx) { + + + } + + @Override + public void enterProgramroot(ProgramrootContext ctx) { + + + } + 
+ @Override + public void enterRelationalExpression(RelationalExpressionContext ctx) { + + + } + + @Override + public void enterStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) { + + + } + + @Override + public void enterStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext ctx) { + + + } + + @Override + public void enterTypedArgAssign(TypedArgAssignContext ctx) { + + + } + + @Override + public void enterTypedArgNoAssign(TypedArgNoAssignContext ctx) { + + + } + + @Override + public void enterUnaryExpression(UnaryExpressionContext ctx) { + + + } + + @Override + public void enterValueType(ValueTypeContext ctx) { + + + } + + @Override + public void enterWhileStatement(WhileStatementContext ctx) { + + + } + + @Override + public void exitAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertAfter(ctx.stop, ";\n"); + } + + } + + @Override + public void exitAddSubExpression(AddSubExpressionContext ctx) { + + + } + + @Override + public void exitAssignmentStatement(AssignmentStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertAfter(ctx.stop, ";\n"); + } + + } + + @Override + public void exitAtomicExpression(AtomicExpressionContext ctx) { + + + } + + @Override + public void exitBooleanAndExpression(BooleanAndExpressionContext ctx) { + + + } + + @Override + public void exitBooleanNotExpression(BooleanNotExpressionContext ctx) { + + + } + + @Override + public void exitBooleanOrExpression(BooleanOrExpressionContext ctx) { + + + } + + @Override + public void exitBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) { + + + } + + @Override + public void exitCommandlineParamExpression(CommandlineParamExpressionContext ctx) { + + + } + + @Override + public void exitCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) { + + + } + + @Override + public void 
exitConstDoubleIdExpression(ConstDoubleIdExpressionContext ctx) { + + + } + + @Override + public void exitConstFalseExpression(ConstFalseExpressionContext ctx) { + + + } + + @Override + public void exitConstIntIdExpression(ConstIntIdExpressionContext ctx) { + + + } + + @Override + public void exitConstStringIdExpression(ConstStringIdExpressionContext ctx) { + + + } + + @Override + public void exitConstTrueExpression(ConstTrueExpressionContext ctx) { + + + } + + @Override + public void exitDataIdExpression(DataIdExpressionContext ctx) { + + + } + + @Override + public void exitEveryRule(ParserRuleContext arg0) { + + + } + + @Override + public void exitExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) { + + + } + + + + @Override + public void exitFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertAfter(ctx.stop, ";\n"); + } + + } + + @Override + public void exitFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertAfter(ctx.stop, ";\n"); + } + + } + + @Override + public void exitIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertAfter(ctx.stop, ";\n"); + } + + } + + @Override + public void exitIfStatement(IfStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + if(ctx.ifBody != null && ctx.ifBody.size() > 0) + rewriter.insertBefore(ctx.ifBody.get(0).start, "\n"); + if(ctx.elseBody != null && ctx.elseBody.size() > 0) + rewriter.insertBefore(ctx.elseBody.get(0).start, "\n"); + rewriter.insertAfter(ctx.stop, "\n"); + } + + } + + @Override + public void exitImportStatement(ImportStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertAfter(ctx.stop, ";\n"); + } + + } + + @Override + public void 
exitIterablePredicateColonExpression(IterablePredicateColonExpressionContext ctx) { + + + } + + @Override + public void exitIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) { + + + } + + @Override + public void exitMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) { + + + } + + @Override + public void exitMatrixMulExpression(MatrixMulExpressionContext ctx) { + + + } + + @Override + public void exitMl_type(Ml_typeContext ctx) { + + + } + + @Override + public void exitModIntDivExpression(ModIntDivExpressionContext ctx) { + + + } + + @Override + public void exitMultDivExpression(MultDivExpressionContext ctx) { + + + } + + @Override + public void exitMultiIdExpression(MultiIdExpressionContext ctx) { + + + } + + @Override + public void exitParameterizedExpression(ParameterizedExpressionContext ctx) { + + + } + + @Override + public void exitPathStatement(PathStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + rewriter.insertAfter(ctx.stop, ";\n"); + } + + } + + @Override + public void exitPowerExpression(PowerExpressionContext ctx) { + + + } + + @Override + public void exitProgramroot(ProgramrootContext ctx) { + + + } + + @Override + public void exitRelationalExpression(RelationalExpressionContext ctx) { + + + } + + @Override + public void exitStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) { + + + } + + @Override + public void exitStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext ctx) { + + + } + + @Override + public void exitTypedArgAssign(TypedArgAssignContext ctx) { + + + } + + @Override + public void exitTypedArgNoAssign(TypedArgNoAssignContext ctx) { + + + } + + @Override + public void exitUnaryExpression(UnaryExpressionContext ctx) { + + + } + + @Override + public void exitValueType(ValueTypeContext ctx) { + + + } + + @Override + public void exitWhileStatement(WhileStatementContext ctx) { + if(currentFunction != null && isRewritePhase) { + if(ctx.body != null && 
ctx.body.size() > 0) + rewriter.insertBefore(ctx.body.get(0).start, "\n"); + rewriter.insertAfter(ctx.stop, "\n"); + } + } + + @Override + public String falseStringLiteral() { + + return null; + } + + @Override + protected Expression handleLanguageSpecificFunction(ParserRuleContext ctx, String functionName, + ArrayList<ParameterExpression> paramExpressions) { + + return null; + } + + @Override + public String namespaceResolutionOp() { + + return null; + } + + @Override + public String trueStringLiteral() { + + return null; + } + + @Override + public void visitErrorNode(ErrorNode arg0) { + + + } + + @Override + public void visitTerminal(TerminalNode arg0) { + + + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java b/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java new file mode 100644 index 0000000..d3b0d11 --- /dev/null +++ b/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.parser.dml;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Random;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+
+/**
+ * This class is a simple container class used to hold the function to be inlined.
+ * It contains the function name, body and also the input and return arguments.
+ * The user invokes getInlinedDML method to get the inlined DML code.
+ */
+public class InlineableMethods {
+	ArrayList<String> _variables;
+	final String _body;
+	final String _fnName;
+	final ArrayList<String> _inputArgs;
+	final ArrayList<String> _retVariables;
+	static int CALLER_ID = 1;
+	
+	public InlineableMethods(String fnName, String body, HashSet<String> variables, ArrayList<String> inputArgs, ArrayList<String> retVariables) {
+		_fnName = fnName;
+		_body = body;
+		_variables = new ArrayList<String>(variables);
+		_variables.sort(Comparator.comparing(String::length).reversed()); // longest first, so a name contained in a longer one is not replaced inside it
+		_inputArgs = inputArgs;
+		_retVariables = retVariables;
+	}
+	
+	public ArrayList<String> getLocalVariables() {
+		return _variables;
+	}
+	
+	private String _getInlinedDML(HashMap<String, String> actualArguments) {
+		String ret = _body;
+		int callerID = CALLER_ID++;
+		for(String var : _variables) { // literal replace: neither the name nor the argument text may be interpreted as a regex
+			String originalVarName = var.substring(InlineHelper.ARG_PREFIX.length());
+			if(actualArguments.containsKey(var)) {
+				ret = ret.replace(var, actualArguments.get(var));
+			}
+			else {
+				// internal variable: rename it with a random, per-call-unique prefix so it cannot clash with caller variables
+				ret = ret.replace(var, LOCAL_ARG_PREFIX + _fnName + "_" + callerID + "_" + originalVarName);
+			}
+		}
+		return ret;
+	}
+	
+	public String getInlinedDML(ArrayList<String> actualInputArgs, ArrayList<String> actualRetVariables) {
+		HashMap<String, String> actualArguments = new HashMap<>();
+		if(actualInputArgs.size() != _inputArgs.size()) {
+			throw new DMLRuntimeException("Incorrect number of input arguments for the function " + _fnName + ": expected " 
+				+ _inputArgs.size() + " (" + String.join(", ", _inputArgs) + ") but found " + actualInputArgs.size() 
+				+ " (" + String.join(", ", actualInputArgs) + ")");
+		}
+		if(actualRetVariables.size() != _retVariables.size()) {
+			throw new DMLRuntimeException("Incorrect number of return variables for the function " + _fnName + ": expected " 
+				+ _retVariables.size() + " (" + String.join(", ", _retVariables) + ") but found " + actualRetVariables.size() 
+				+ " (" + String.join(", ", actualRetVariables) + ")");
+		}
+		for(int i = 0; i < _inputArgs.size(); i++) {
+			actualArguments.put(_inputArgs.get(i), actualInputArgs.get(i));
+		}
+		for(int i = 0; i < _retVariables.size(); i++) {
+			actualArguments.put(_retVariables.get(i), actualRetVariables.get(i));
+		}
+		return _getInlinedDML(actualArguments);
+	}
+	
+	static final String LOCAL_ARG_PREFIX;
+	static {
+		Random rand = new Random();
+		LOCAL_ARG_PREFIX = "LOCAL_" + Long.toUnsignedString(rand.nextLong()) + "_" + Long.toUnsignedString(rand.nextLong());
+		// toUnsignedString, not Math.abs: Math.abs(Long.MIN_VALUE) is negative and would put '-' into the identifier
+	}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index 03bc3b3..15dd23e 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -307,7 +307,14 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 	}
 	
 	public MatrixCharacteristics getMatrixCharacteristics() {
-		return _metaData.getMatrixCharacteristics();
+		MatrixCharacteristics mc = _metaData.getMatrixCharacteristics();
+		if(mc.getRowsPerBlock() == -1) {
+			mc.setRowsPerBlock(OptimizerUtils.DEFAULT_BLOCKSIZE);
+		}
+		if(mc.getColsPerBlock() == -1) {
+			mc.setColsPerBlock(OptimizerUtils.DEFAULT_BLOCKSIZE);
+		}
+		return mc;
 	}
 	
 	public abstract void refreshMetaData();
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index 6772b4a..a08d4fd 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -298,7 +298,7 @@ public class GPUMemoryManager {
 			evictOrClear(sizeBasedUnlockedGPUObjects.get(), opcode);
 			A = cudaMallocNoWarn(tmpA, size, null);
 			if(A == null)
-				LOG.warn("cudaMalloc failed after clearing/evicting based on size.");
+				LOG.debug("cudaMalloc failed after clearing/evicting based on size.");
 			if(ConfigurationManager.isStatistics()) {
 				long totalTime = System.nanoTime() - t0;
 				GPUStatistics.cudaEvictTime.add(totalTime);
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index fbcd3e2..8a100b4 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -924,7 +924,7 @@ class Caffe2DML(BaseSystemMLClassifier):
         self.estimator.setWeightsToIgnore(ignore_weights)
 
     def set(self, debug=None, train_algo=None, test_algo=None, parallel_batches=None,
-            output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None):
+            output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None, inline_nn_library=None):
         """
         Set input to Caffe2DML
 
@@ -937,9 +937,12 @@ class Caffe2DML(BaseSystemMLClassifier):
         output_activations: (developer flag) directory to output activations of each layer as csv while prediction. To be used only in batch mode (default: None)
         perform_one_hot_encoding: should perform one-hot encoding in DML using table function (default: False)
         parfor_parameters: dictionary for parfor parameters when using allreduce-style algorithms (default: "")
+        inline_nn_library: whether to inline the NN library when generating DML using Caffe2DML (default: False)
         """
         if debug is not None:
             self.estimator.setInput("$debug", str(debug).upper())
+        if inline_nn_library is not None:
+            self.estimator.setInput("$inline_nn_library", str(inline_nn_library).upper())
         if train_algo is not None:
             self.estimator.setInput("$train_algo", str(train_algo).lower())
         if test_algo is not None:
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index 26e554f..8ddb1fe 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -119,8 +119,9 @@ object Caffe2DML {
   val LOG = LogFactory.getLog(classOf[Caffe2DML].getName())
   // ------------------------------------------------------------------------
   val USE_PLUS_EQ = true
-  def layerDir = "nn/layers/"
-  def optimDir = "nn/optim/"
+  def nnDir = "nn/"
+  def layerDir = nnDir + "layers/"
+  def optimDir = nnDir + "optim/"
 
   // Naming conventions:
   val X    = "X"; val y        = "y"; val batchSize = "BATCH_SIZE"; val numImages = "num_images"; val numValidationImages = "num_validation"
@@ -159,6 +160,7 @@ object Caffe2DML {
   val BATCH_ALGORITHM = "batch"
   val ALLREDUCE_ALGORITHM = "allreduce"
   val ALLREDUCE_PARALLEL_BATCHES_ALGORITHM = "allreduce_parallel_batches"
+  var INLINE_NN_LIBRARY = false
 }
 
 class Caffe2DML(val sc: SparkContext,
@@ -312,6 +314,7 @@ class Caffe2DML(val sc: SparkContext,
 
     // Flags passed by user
     val DEBUG_TRAINING = if (inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
+    Caffe2DML.INLINE_NN_LIBRARY = if (inputs.containsKey("$inline_nn_library")) inputs.get("$inline_nn_library").toLowerCase.toBoolean else false
     assign(tabDMLScript, "debug", if (DEBUG_TRAINING) "TRUE" else "FALSE")
     setDebugFlags(DEBUG_TRAINING)
 
@@ -721,6 +724,7 @@ class Caffe2DMLModel(val numClasses: String, val sc: SparkContext, val solver: C
     reset // Reset the state of DML generator for training script.
 
     val DEBUG_PREDICTION = if (estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+    Caffe2DML.INLINE_NN_LIBRARY = if (estimator.inputs.containsKey("$inline_nn_library")) estimator.inputs.get("$inline_nn_library").toLowerCase.toBoolean else false
     assign(tabDMLScript, "debug", if (DEBUG_PREDICTION) "TRUE" else "FALSE")
     estimator.setDebugFlags(DEBUG_PREDICTION)
 
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index d664f6e..b290983 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -29,6 +29,7 @@ import caffe.Caffe.EltwiseParameter.EltwiseOp
 import org.apache.sysml.runtime.DMLRuntimeException;
 import java.util.ArrayList
 import caffe.Caffe.PoolingParameter.PoolMethod
+import scala.collection.JavaConverters._
 
 trait CaffeLayer extends BaseDMLGenerator {
   // -------------------------------------------------
@@ -125,7 +126,7 @@ trait CaffeLayer extends BaseDMLGenerator {
   // The layers that have a corresponding dml script call this method.
   // Assumption: the first variable of resultVariables is always dX
   def invokeBackward(dmlScript: StringBuilder, outSuffix: String, resultVariables: List[String], arguments: String*): Unit = {
-    invoke(dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList, false)
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList, false)
     val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
     dmlScript.append("; ")
     bottomLayerIDs.map(bottomLayerID => dmlScript.append(dX(bottomLayerID) + outSuffix + " = " + resultVariables(0) + outSuffix + "; "))
@@ -140,6 +141,13 @@ trait CaffeLayer extends BaseDMLGenerator {
     dmlScript.append("\n")
   }
   // --------------------------------------------------------------------------------------
+  
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String]): Unit =
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, functionName, arguments, true)
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, appendNewLine: Boolean, arguments: String*): Unit =
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, functionName, arguments.toList, appendNewLine)
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: String*): Unit =
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, functionName, arguments.toList, true)
 }
 
 trait IsLossLayer extends CaffeLayer {
@@ -1603,4 +1611,4 @@ class DeConvolution(val param: LayerParameter, val id: Int, val net: CaffeNetwor
     if (convParam.hasPadW) convParam.getPadW.toString
     else if (convParam.getPadCount > 0) convParam.getPad(0).toString
     else "0"
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
index 8559c60..da963b9 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
@@ -67,10 +67,12 @@ trait CaffeSolver {
       val hasDecayMult = layer.param.getParamList != null && layer.param.getParamList.size >= 1 && layer.param.getParamList.get(0).hasDecayMult
       val newLambda = if(hasDecayMult) layer.param.getParamList.get(0).getDecayMult * lambda else lambda
       
-      dmlScript.append("\t").append(layer.dWeight + "_reg = " + regularizationSource + "::backward(" + layer.weight + ", " + newLambda + ")\n")
+      Utils.invoke(Caffe2DML.layerDir, dmlScript, regularizationSource + "::", List[String](layer.dWeight + "_reg"), "backward", 
+          List[String](layer.weight, "" + newLambda), true)
       dmlScript.append("\t").append(layer.dWeight + " = " + layer.dWeight + " + " + layer.dWeight + "_reg\n")
       if(layer.shouldUpdateExtraWeight) {
-        dmlScript.append("\t").append(layer.dExtraWeight + "_reg = " + regularizationSource + "::backward(" + layer.extraWeight + ", " + newLambda + ")\n")
+        Utils.invoke(Caffe2DML.layerDir, dmlScript, regularizationSource + "::", List[String](layer.dExtraWeight + "_reg"), "backward", 
+            List[String](layer.extraWeight, "" + newLambda), true)
         dmlScript.append("\t").append(layer.dExtraWeight + " = " + layer.dExtraWeight + " + " + layer.dExtraWeight + "_reg\n")
       }
     }
@@ -339,32 +341,23 @@ class Nesterov(regularizationType:String = "L2", lambda: Double = 5e-04, momentu
    * input v.
    */
   def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
-    val fn            = if (Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
-    val lastParameter = if (Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
     if (!Caffe2DML.USE_NESTEROV_UDF) {
       regularization_update(regularizationType, lambda, dmlScript, layer)
     }
     if (layer.shouldUpdateWeight)
-      dmlScript
-        .append("\t")
-        .append(
-          "[" + commaSep(layer.weight, layer.weight + "_v") + "] " +
-          "= " + fn + "(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight + "_v") + lastParameter + ")\n"
-        )
+      nesterovUpdate(dmlScript, layer.weight, layer.dWeight, getWeightLr(layer))
     if (layer.shouldUpdateExtraWeight)
-      dmlScript
-        .append("\t")
-        .append(
-          "[" + commaSep(layer.extraWeight, layer.extraWeight + "_v") + "] " +
-          "= " + fn + "(" + commaSep(layer.extraWeight, layer.dExtraWeight, getWeightLr(layer), momentum.toString, layer.extraWeight + "_v") + lastParameter + ")\n"
-        )
+      nesterovUpdate(dmlScript, layer.extraWeight, layer.dExtraWeight, getWeightLr(layer))
     if (layer.shouldUpdateBias)
-      dmlScript
-        .append("\t")
-        .append(
-          "[" + commaSep(layer.bias, layer.bias + "_v") + "] " +
-          "= " + fn + "(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias + "_v") + lastParameter + ")\n"
-        )
+      nesterovUpdate(dmlScript, layer.bias, layer.dBias, getBiasLr(layer))
   }
+  // Emits the momentum update for one parameter w (velocity w + "_v"). update_nesterov is called without a namespace and
+  // receives lambda itself (regularization_update is skipped for it), so it is emitted directly; routing it through
+  // sgd_nesterov::update, as the inlining change did, silently dropped regularization when USE_NESTEROV_UDF is set.
+  private def nesterovUpdate(dmlScript: StringBuilder, w: String, dw: String, lr: String): Unit =
+    if (Caffe2DML.USE_NESTEROV_UDF)
+      dmlScript.append("\t[" + commaSep(w, w + "_v") + "] = update_nesterov(" + commaSep(w, dw, lr, momentum.toString, w + "_v") + ", " + lambda + ")\n")
+    else
+      Utils.invoke(Caffe2DML.optimDir, dmlScript, "sgd_nesterov::", List[String](w, w + "_v"), "update", List[String](w, dw, lr, momentum.toString, w + "_v"), true)
   def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
     if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_v = sgd_nesterov::init(" + layer.weight + ")\n")
@@ -372,4 +365,4 @@ class Nesterov(regularizationType:String = "L2", lambda: Double = 5e-04, momentu
     if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_v = sgd_nesterov::init(" + layer.bias + ")\n")
   }
   def sourceFileName: String = "sgd_nesterov"
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 60396f1..59c75ad 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -75,39 +75,6 @@ trait BaseDMLGenerator {
     sum(dmlScript, rhsVars)
     dmlScript.append("\n")
   }
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String]): Unit =
-    invoke(dmlScript, namespace1, returnVariables, functionName, arguments, true)
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String], appendNewLine: Boolean): Unit = {
-    if (returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
-    if (returnVariables.length > 1) dmlScript.append("[")
-    dmlScript.append(returnVariables(0))
-    if (returnVariables.length > 1) {
-      for (i <- 1 until returnVariables.length) {
-        dmlScript.append(",").append(returnVariables(i))
-      }
-      dmlScript.append("]")
-    }
-    dmlScript.append(" = ")
-    dmlScript.append(namespace1)
-    dmlScript.append(functionName)
-    dmlScript.append("(")
-    if (arguments != null) {
-      if (arguments.length != 0)
-        dmlScript.append(arguments(0))
-      if (arguments.length > 1) {
-        for (i <- 1 until arguments.length) {
-          dmlScript.append(",").append(arguments(i))
-        }
-      }
-    }
-    dmlScript.append(")")
-    if (appendNewLine)
-      dmlScript.append("\n")
-  }
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, appendNewLine: Boolean, arguments: String*): Unit =
-    invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, appendNewLine)
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: String*): Unit =
-    invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, true)
   def rightIndexing(dmlScript: StringBuilder, lhsVar:String, rhsVar: String, rl: String, ru: String, cl: String=null, cu: String=null): StringBuilder = {
     dmlScript.append(lhsVar).append(" = ").append(rhsVar).append("[")
     if (rl != null && ru != null) dmlScript.append(rl).append(":").append(ru)
@@ -279,6 +246,7 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator {
     // Append source statements for layers as well as solver
     source(net, solver, if (isTraining) Array[String]("l1_reg") else null)
     source(net, solver, if (isTraining) Array[String]("l2_reg") else null)
+    source(dmlScript, numTabs, "util", Caffe2DML.nnDir)
 
     if (isTraining) {
       // Append external built-in function headers:
@@ -346,4 +314,4 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator {
 
   def updateMeanVarianceForBatchNorm(net: CaffeNetwork, value: Boolean): Unit =
     net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = value)
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Utils.scala b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
index 63aaf91..da53f65 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
@@ -39,6 +39,11 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.api.mlcontext.MLContext
 import org.apache.spark.SparkContext
 import org.apache.spark.api.java.JavaSparkContext
+import org.apache.sysml.parser.ParserWrapper
+import org.apache.sysml.parser.dml.DMLParserWrapper
+import org.apache.sysml.parser.dml.InlineableMethods
+import java.util.ArrayList
+import scala.collection.JavaConverters._
 
 object Utils {
   // ---------------------------------------------------------------------------------------------
@@ -64,6 +69,54 @@ object Utils {
       line = bufReader.readLine()
     }
   }
+
+  def readDMLScript(fileName:String):String = ParserWrapper.readDMLScript(fileName, Caffe2DML.LOG)
+
+  // Parsed inlineable functions, keyed by "<source file>#<namespace>" (both are inputs to the parse)
+  // so that each DML file is parsed at most once. NOTE(review): java.util.HashMap is not thread-safe;
+  // this presumes DML generation runs on a single thread -- confirm.
+  val inlineableMethods = new java.util.HashMap[String, java.util.HashMap[String, InlineableMethods]]()
+
+  /**
+   * Returns the inlineable form of the function fnName defined in sourceFilePath, parsing the file on first use.
+   * Throws DMLRuntimeException if the function is not found, rather than failing later with a NullPointerException.
+   */
+  def getInlineableMethod(sourceFilePath:String, namespace:String, fnName:String):InlineableMethods = {
+    val key = sourceFilePath + "#" + namespace
+    var methods = inlineableMethods.get(key)
+    if (methods == null) {
+      methods = new DMLParserWrapper().getInlineableMethods(sourceFilePath, null, namespace, null)
+      inlineableMethods.put(key, methods)
+    }
+    val method = if (methods != null) methods.get(fnName) else null
+    if (method == null)
+      throw new DMLRuntimeException("Cannot inline " + namespace + fnName + ": no such function in " + sourceFilePath)
+    method
+  }
+
+  /**
+   * Appends a call of namespace1 + functionName to dmlScript, e.g. "[a,b] = ns::fn(x,y)".
+   * When Caffe2DML.INLINE_NN_LIBRARY is set, the function body is inlined instead. The body is read from
+   * dir + (namespace1 without its trailing "::") + ".dml". Inlined code always ends with a newline,
+   * regardless of appendNewLine, and inlining is not recursive.
+   */
+  def invoke(dir:String, dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String], appendNewLine: Boolean): Unit = {
+    // A null argument list means "no arguments"; the call-emitting path always accepted null here.
+    val args = if (arguments != null) arguments else List[String]()
+    if (Caffe2DML.INLINE_NN_LIBRARY) {
+      val sourceFileName = if (namespace1.endsWith("::")) namespace1.substring(0, namespace1.length - 2) else namespace1
+      val method = getInlineableMethod(dir + sourceFileName + ".dml", namespace1, functionName)
+      dmlScript.append(method.getInlinedDML(new ArrayList[String](args.asJava), new ArrayList[String](returnVariables.asJava)))
+      dmlScript.append("\n")
+    } else {
+      if (returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
+      // Several return values use DML's "[a,b] = f(...)" form; a single one is assigned directly.
+      val lhs = if (returnVariables.length > 1) returnVariables.mkString("[", ",", "]") else returnVariables(0)
+      dmlScript.append(lhs).append(" = ").append(namespace1).append(functionName).append(args.mkString("(", ",", ")"))
+      if (appendNewLine)
+        dmlScript.append("\n")
+    }
+  }
   // ---------------------------------------------------------------------------------------------
   def parseSolver(solverFilePath: String): CaffeSolver = parseSolver(readCaffeSolver(solverFilePath))
@@ -324,4 +377,4 @@ class Utils {
   def saveCaffeModelFile(sc: JavaSparkContext, deployFilePath: String, caffeModelFilePath: String, outputDirectory: String, format: String): Unit =
     Utils.saveCaffeModelFile(sc, deployFilePath, caffeModelFilePath, outputDirectory, format)
 
-}
+}
\ No newline at end of file