Repository: systemml
Updated Branches:
  refs/heads/master 11c67055a -> ef1945d70


[SYSTEMML-540] Allow user to generate an inlined DML script in Caffe2DML

- The inlining code is generic enough to be extended to perform parser-level 
inlining. This commit allows us to compare the tradeoffs of performing 
script-level inlining vs. hop-level inlining.
- Refactored DMLParserWrapper and also added javadoc.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ef1945d7
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ef1945d7
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ef1945d7

Branch: refs/heads/master
Commit: ef1945d70a85df4f646c315d06a1a094dad6ebb2
Parents: 11c6705
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Thu Oct 11 14:28:59 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Thu Oct 11 14:32:35 2018 -0700

----------------------------------------------------------------------
 .../parser/common/CustomErrorListener.java      |   8 +
 .../sysml/parser/dml/DMLParserWrapper.java      | 130 ++-
 .../java/org/apache/sysml/parser/dml/Dml.g4     |   6 +-
 .../sysml/parser/dml/DmlPreprocessor.java       |  13 +-
 .../apache/sysml/parser/dml/InlineHelper.java   | 798 +++++++++++++++++++
 .../sysml/parser/dml/InlineableMethods.java     |  98 +++
 .../controlprogram/caching/CacheableData.java   |   9 +-
 .../gpu/context/GPUMemoryManager.java           |   2 +-
 src/main/python/systemml/mllearn/estimators.py  |   5 +-
 .../org/apache/sysml/api/dl/Caffe2DML.scala     |   8 +-
 .../org/apache/sysml/api/dl/CaffeLayer.scala    |  12 +-
 .../org/apache/sysml/api/dl/CaffeSolver.scala   |  35 +-
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |  36 +-
 .../scala/org/apache/sysml/api/dl/Utils.scala   |  58 +-
 14 files changed, 1114 insertions(+), 104 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java 
b/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
index 2af5f69..b82afc9 100644
--- a/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
+++ b/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
@@ -22,6 +22,7 @@ package org.apache.sysml.parser.common;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Set;
 
 import org.antlr.v4.runtime.BaseErrorListener;
 import org.antlr.v4.runtime.RecognitionException;
@@ -38,6 +39,9 @@ public class CustomErrorListener extends BaseErrorListener {
        private boolean atLeastOneError = false;
        private boolean atLeastOneWarning = false;
        private String currentFileName = null;
+       
+       // Names of user-defined internal and external function definitions
+       public Set<String> functions;
 
        /**
         * List of parse issues.
@@ -55,6 +59,10 @@ public class CustomErrorListener extends BaseErrorListener {
        public void unsetCurrentFileName() {
                currentFileName = null;
        }
+       
+       public Set<String> getFunctionDefs() {
+               return functions;
+       }
 
        /**
         * Validation error occurred. Add the error to the list of parse issues.

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java 
b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
index 1d7daa1..9b3f8c4 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
@@ -23,12 +23,14 @@ import java.io.ByteArrayInputStream;
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.io.InputStream;
+import java.util.HashMap;
 import java.util.Map;
 
 import org.antlr.v4.runtime.ANTLRInputStream;
 import org.antlr.v4.runtime.BailErrorStrategy;
 import org.antlr.v4.runtime.CommonTokenStream;
 import org.antlr.v4.runtime.DefaultErrorStrategy;
+import org.antlr.v4.runtime.TokenStreamRewriter;
 import org.antlr.v4.runtime.atn.PredictionMode;
 import org.antlr.v4.runtime.misc.ParseCancellationException;
 import org.antlr.v4.runtime.tree.ParseTree;
@@ -74,6 +76,15 @@ import 
org.apache.sysml.parser.dml.DmlParser.StatementContext;
 public class DMLParserWrapper extends ParserWrapper
 {
        private static final Log LOG = 
LogFactory.getLog(DMLScript.class.getName());
+       
+       // Rewriter is only used in getInlineableMethods
+       private TokenStreamRewriter rewriter = null;
+       
+       // The below fields are set in the createAST method
+       // Can be null or the path to the DML file
+       private String fileName; 
+       // Can be null or the DML script. Note: fileName and dmlScript must 
not both be null
+       private String dmlScript;
 
        /**
         * Parses the passed file with command line parameters. You can either 
pass both (local file) or just dmlScript (hdfs) or just file name (import 
command)
@@ -88,17 +99,72 @@ public class DMLParserWrapper extends ParserWrapper
        }
        
        /**
-        * This function is supposed to be called directly only from 
DmlSyntacticValidator when it encounters 'import'
-        * @param fileName script file name
-        * @param dmlScript script file contents
-        * @param sourceNamespace namespace from source statement
-        * @param argVals script arguments
-        * @return dml program, or null if at least one error
+        * Performs preprocess using DmlPreprocessor listener class.
+        * 
+        * @param tree parse tree generated by createAST method
+        * @param errorListener listener that captures potential syntactic 
errors 
+        * @return a parse tree walker to perform further validation
         */
-       public DMLProgram doParse(String fileName, String dmlScript, String 
sourceNamespace, Map<String,String> argVals) {
-               DMLProgram dmlPgm = null;
+       ParseTreeWalker preprocess(ParseTree tree, CustomErrorListener 
errorListener) {
+               ParseTreeWalker walker = new ParseTreeWalker();
+               // Get list of function definitions which take precedence over 
built-in functions if same name
+               walker.walk(new DmlPreprocessor(errorListener),  tree);
+               return walker;
+       }
+       
+       /**
+        * Get the inline-able methods
+        * 
+        * @param fileName1 can be null or the path to the DML file
+        * @param dmlScript1 can be null or the DML script. Note: fileName1 
and dmlScript1 must not both be null.
+        * @param sourceNamespace source namespace
+        * @param argVals command-line arguments
+        * @return hashmap of inline-able methods
+        */
+       public HashMap<String, InlineableMethods> getInlineableMethods(String 
fileName1, String dmlScript1, String sourceNamespace, Map<String,String> 
argVals) {
+               // Create AST and do preprocessing
+               CustomErrorListener errorListener = new CustomErrorListener();
+               ParseTree tree = createAST(fileName1, dmlScript1, 
sourceNamespace, argVals, errorListener, true);
+               ParseTreeWalker walker = preprocess(tree, errorListener);
+                               
+               // Note: this method uses InlineHelper as a listener to perform 
rewriting of local variables
+               // It does so in two phases:
+               // Phase 1. Rewriting phase where local variables are rewritten 
by adding a prefix.
+               // Phase 2. Capture the body of the functions using 
InlineableMethods class
+               
+               // Rewrite all the local variables by adding prefix 
+               InlineHelper validator = new InlineHelper(errorListener, 
argVals, sourceNamespace, errorListener.getFunctionDefs(), rewriter);
+               validator.setPhase(true);
+               walker.walk(validator, tree);
                
-               ANTLRInputStream in;
+               // Use the rewritten text as the new DML script and create AST 
again
+               fileName = null; dmlScript = rewriter.getText();
+               errorListener = new CustomErrorListener();
+               tree = createAST(fileName, dmlScript, sourceNamespace, argVals, 
errorListener, true);
+               walker = preprocess(tree, errorListener);
+                               
+               // Put the content of rewritten function body in the inlineMap
+               validator.setPhase(false);
+               walker.walk(validator, tree);
+               
+               return validator.inlineMap;
+       }
+       
+       /**
+        * Create an ANTLR parse tree for the input DML script
+        * 
+        * @param fileName1 can be null or the path to the DML file
+        * @param dmlScript1 can be null or the DML script. Note: fileName1 
and dmlScript1 must not both be null.
+        * @param sourceNamespace source namespace
+        * @param argVals command-line arguments
+        * @param errorListener listener that captures potential syntactic 
errors
+        * @param performRewriting should perform rewriting of tokens
+        * @return a parse tree
+        */
+       private ParseTree createAST(String fileName1, String dmlScript1, String 
sourceNamespace, Map<String,String> argVals, CustomErrorListener errorListener, 
boolean performRewriting) {
+               ANTLRInputStream in = null;
+               this.fileName = fileName1;
+               this.dmlScript = dmlScript1;
                try {
                        if(dmlScript == null) {
                                dmlScript = readDMLScript(fileName, LOG);
@@ -113,13 +179,13 @@ public class DMLParserWrapper extends ParserWrapper
                } catch (LanguageException e) {
                        throw new ParseException(e.getMessage(), e);
                }
-
-               ProgramrootContext ast = null;
-               CustomErrorListener errorListener = new CustomErrorListener();
                
+               ProgramrootContext ast = null;
                try {
                        DmlLexer lexer = new DmlLexer(in);
                        CommonTokenStream tokens = new CommonTokenStream(lexer);
+                       if(performRewriting)
+                               rewriter = new TokenStreamRewriter(tokens);
                        DmlParser antlr4Parser = new DmlParser(tokens);
                        
                        boolean tryOptimizedParsing = false; // For now no 
optimization, since it is not able to parse integer value. 
@@ -163,19 +229,31 @@ public class DMLParserWrapper extends ParserWrapper
                catch(Exception e) {
                        throw new ParseException("ERROR: Cannot parse the 
program:" + fileName, e);
                }
+               return ast;
+       }
+       
+       
+       
+       /**
+        * This function is supposed to be called directly only from 
DmlSyntacticValidator when it encounters 'import'
+        * 
+        * @param fileName1 script file name
+        * @param dmlScript1 script file contents
+        * @param sourceNamespace namespace from source statement
+        * @param argVals script arguments
+        * @return dml program, or null if at least one error
+        */
+       public DMLProgram doParse(String fileName1, String dmlScript1, String 
sourceNamespace, Map<String,String> argVals) {
+               // Create AST and do preprocessing
+               CustomErrorListener errorListener = new CustomErrorListener();
+               ParseTree tree = createAST(fileName1, dmlScript1, 
sourceNamespace, argVals, errorListener, false);
+               ParseTreeWalker walker = preprocess(tree, errorListener);
                
-
-               // Now convert the parse tree into DMLProgram
-               // Do syntactic validation while converting 
-               ParseTree tree = ast;
-               // And also do syntactic validation
-               ParseTreeWalker walker = new ParseTreeWalker();
-               // Get list of function definitions which take precedence over 
built-in functions if same name
-               DmlPreprocessor prep = new DmlPreprocessor(errorListener);
-               walker.walk(prep,  tree);
-               // Syntactic validation
-               DmlSyntacticValidator validator = new 
DmlSyntacticValidator(errorListener, argVals, sourceNamespace, 
prep.getFunctionDefs());
+               // Perform syntactic validation using DmlSyntacticValidator 
listener
+               DmlSyntacticValidator validator = new 
DmlSyntacticValidator(errorListener, argVals, sourceNamespace, 
errorListener.getFunctionDefs());
                walker.walk(validator, tree);
+               
+               // Check for parse issues and warnings
                errorListener.unsetCurrentFileName();
                this.parseIssues = errorListener.getParseIssues();
                this.atLeastOneWarning = errorListener.isAtLeastOneWarning();
@@ -186,9 +264,9 @@ public class DMLParserWrapper extends ParserWrapper
                if (atLeastOneWarning) {
                        
LOG.warn(CustomErrorListener.generateParseIssuesMessage(dmlScript, 
parseIssues));
                }
-               dmlPgm = createDMLProgram(ast, sourceNamespace);
                
-               return dmlPgm;
+               // Create and return the DML program
+               return createDMLProgram((ProgramrootContext)tree, 
sourceNamespace);
        }
        
        private static DMLProgram createDMLProgram(ProgramrootContext ast, 
String sourceNamespace) {
@@ -255,4 +333,4 @@ public class DMLParserWrapper extends ParserWrapper
                
                return dmlPgm;
        }
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/Dml.g4
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 
b/src/main/java/org/apache/sysml/parser/dml/Dml.g4
index 46ee178..08e1288 100644
--- a/src/main/java/org/apache/sysml/parser/dml/Dml.g4
+++ b/src/main/java/org/apache/sysml/parser/dml/Dml.g4
@@ -216,6 +216,6 @@ COMMANDLINE_POSITION_ID: '$' DIGIT+;
 STRING: '"' ( ESC | ~[\\"] )*? '"' | '\'' ( ESC | ~[\\'] )*? '\'';
 fragment ESC : '\\' [btnfr"'\\] ;
 // Comments, whitespaces and new line
-LINE_COMMENT : '#' .*? '\r'? '\n' -> skip ;
-MULTILINE_BLOCK_COMMENT : '/*' .*? '*/' -> skip ;
-WHITESPACE : (' ' | '\t' | '\r' | '\n')+ -> skip ;
+LINE_COMMENT : '#' .*? '\r'? '\n' -> channel(HIDDEN) ;
+MULTILINE_BLOCK_COMMENT : '/*' .*? '*/' -> channel(HIDDEN) ;
+WHITESPACE : (' ' | '\t' | '\r' | '\n')+ -> channel(HIDDEN) ;

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java 
b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
index 56eb8ca..49e96e7 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
@@ -20,7 +20,6 @@
 package org.apache.sysml.parser.dml;
 
 import java.util.HashSet;
-import java.util.Set;
 
 import org.antlr.v4.runtime.ParserRuleContext;
 import org.antlr.v4.runtime.Token;
@@ -84,16 +83,10 @@ import 
org.apache.sysml.parser.dml.DmlParser.WhileStatementContext;
 public class DmlPreprocessor implements DmlListener {
 
        protected final CustomErrorListener errorListener;
-       // Names of user internal and external functions definitions
-       protected Set<String> functions;
 
        public DmlPreprocessor(CustomErrorListener errorListener) {
                this.errorListener = errorListener;
-               functions = new HashSet<>();
-       }
-
-       public Set<String> getFunctionDefs() {
-               return functions;
+               this.errorListener.functions = new HashSet<>();
        }
        
        @Override
@@ -113,8 +106,8 @@ public class DmlPreprocessor implements DmlListener {
        public void 
exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {}
 
        protected void validateFunctionName(String name, ParserRuleContext ctx) 
{
-               if (!functions.contains(name)) {
-                       functions.add(name);
+               if (!errorListener.functions.contains(name)) {
+                       errorListener.functions.add(name);
                }
                else {
                        notifyErrorListeners("Function Name Conflict: '" + name 
+ "' already defined in " + errorListener.getCurrentFileName(), ctx.start);

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java 
b/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java
new file mode 100644
index 0000000..34d886c
--- /dev/null
+++ b/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java
@@ -0,0 +1,798 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.parser.dml;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+import org.antlr.v4.runtime.ParserRuleContext;
+import org.antlr.v4.runtime.Token;
+import org.antlr.v4.runtime.TokenStreamRewriter;
+import org.antlr.v4.runtime.tree.ErrorNode;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.ParameterExpression;
+import org.apache.sysml.parser.common.CommonSyntacticValidator;
+import org.apache.sysml.parser.common.CustomErrorListener;
+import 
org.apache.sysml.parser.dml.DmlParser.AccumulatorAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.AddSubExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.AssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.AtomicExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BooleanAndExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BooleanNotExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BooleanOrExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BuiltinFunctionExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.CommandlineParamExpressionContext;
+import 
org.apache.sysml.parser.dml.DmlParser.CommandlinePositionExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstDoubleIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstFalseExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstIntIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstStringIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstTrueExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.DataIdExpressionContext;
+import 
org.apache.sysml.parser.dml.DmlParser.ExternalFunctionDefExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ForStatementContext;
+import 
org.apache.sysml.parser.dml.DmlParser.FunctionCallAssignmentStatementContext;
+import 
org.apache.sysml.parser.dml.DmlParser.FunctionCallMultiAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.IfStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.IfdefAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.ImportStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.IndexedExpressionContext;
+import 
org.apache.sysml.parser.dml.DmlParser.InternalFunctionDefExpressionContext;
+import 
org.apache.sysml.parser.dml.DmlParser.IterablePredicateColonExpressionContext;
+import 
org.apache.sysml.parser.dml.DmlParser.IterablePredicateSeqExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.MatrixDataTypeCheckContext;
+import org.apache.sysml.parser.dml.DmlParser.MatrixMulExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.Ml_typeContext;
+import org.apache.sysml.parser.dml.DmlParser.ModIntDivExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.MultDivExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.MultiIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ParForStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.ParameterizedExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.PathStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.PowerExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ProgramrootContext;
+import org.apache.sysml.parser.dml.DmlParser.RelationalExpressionContext;
+import 
org.apache.sysml.parser.dml.DmlParser.SimpleDataIdentifierExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.StatementContext;
+import 
org.apache.sysml.parser.dml.DmlParser.StrictParameterizedExpressionContext;
+import 
org.apache.sysml.parser.dml.DmlParser.StrictParameterizedKeyValueStringContext;
+import org.apache.sysml.parser.dml.DmlParser.TypedArgAssignContext;
+import org.apache.sysml.parser.dml.DmlParser.TypedArgNoAssignContext;
+import org.apache.sysml.parser.dml.DmlParser.UnaryExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ValueTypeContext;
+import org.apache.sysml.parser.dml.DmlParser.WhileStatementContext;
+
+/**
+ * This class is used to generate inline-able methods.
+ * It does so in two phases:
+ * - Phase 1. Rewriting phase where local variables are rewritten by adding a 
prefix.
+ * - Phase 2. Capture the body of the functions using InlineableMethods class 
+ */
+public class InlineHelper extends CommonSyntacticValidator implements 
DmlListener {
+       final static String ARG_PREFIX;
+       static {
+               Random rand = new Random();
+               ARG_PREFIX = "INTERNAL_PREFIX_" + Math.abs(rand.nextLong()) + 
"_" + Math.abs(rand.nextLong()) + "_"; 
+       }
+       public HashMap<String, InlineableMethods> inlineMap = new HashMap<>();
+       TokenStreamRewriter rewriter;
+       
+       // Set internally
+       HashSet<String> variables = new HashSet<>();
+       String currentFunction = null;
+       boolean isRewritePhase;
+       
+       public InlineHelper(CustomErrorListener errorListener, Map<String, 
String> argVals, String sourceNamespace,
+                       Set<String> prepFunctions, TokenStreamRewriter 
rewriter1) {
+               super(errorListener, argVals, sourceNamespace, prepFunctions);
+               rewriter = rewriter1;
+       }
+       
+       void setPhase(boolean isRewritePhase1) {
+               isRewritePhase = isRewritePhase1;
+       }
+       
+
+       @Override
+       public void 
enterInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {
+               currentFunction = ctx.name.getText();
+               variables.clear();
+       }
+       
+       @Override
+       public void 
exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {
+               if(!isRewritePhase) {
+                       StringBuilder sb = new StringBuilder();
+                       for(StatementContext stmt : ctx.body) {
+                               sb.append(stmt.getText());
+                               sb.append("\n");
+                       }
+                       ArrayList<String> inputArgs = new ArrayList<>(); 
+                       for(TypedArgAssignContext in : ctx.inputParams) {
+                               inputArgs.add(ARG_PREFIX + 
in.paramName.getText());
+                       }
+                       ArrayList<String> retVariables = new ArrayList<>();
+                       for(TypedArgNoAssignContext out : ctx.outputParams) {
+                               retVariables.add(ARG_PREFIX + 
out.paramName.getText());
+                       }
+                       
+                       inlineMap.put(currentFunction, new 
InlineableMethods(currentFunction, sb.toString(), variables, inputArgs, 
retVariables));
+               }
+               currentFunction = null;
+               variables.clear();
+       }
+       
+       @Override
+       public void enterIndexedExpression(IndexedExpressionContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertBefore(ctx.name, " " + ARG_PREFIX);
+               }
+       }
+       
+       @Override
+       public void exitIndexedExpression(IndexedExpressionContext ctx) {
+               if(currentFunction != null)
+                       variables.add(ctx.name.getText());
+       }
+       
+       @Override
+       public void 
enterSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertBefore(ctx.start, " " + ARG_PREFIX);
+                       rewriter.insertAfter(ctx.stop, " ");
+               }
+       }
+       
+       @Override
+       public void 
exitSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) {
+               if(currentFunction != null)
+                       variables.add(ctx.getText());
+       }
+       
+       @Override
+       public void enterForStatement(ForStatementContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertBefore(ctx.iterVar, " " + ARG_PREFIX);
+                       rewriter.insertAfter(ctx.iterVar, " ");
+               }
+       }
+       
+       @Override
+       public void enterParForStatement(ParForStatementContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertBefore(ctx.iterVar, " " + ARG_PREFIX);
+                       rewriter.insertAfter(ctx.iterVar, " ");
+               }
+       }
+       
+       @Override
+       public void exitForStatement(ForStatementContext ctx) {
+               if(currentFunction != null)
+                       variables.add(ctx.iterVar.getText());
+               if(currentFunction != null && isRewritePhase) {
+                       if(ctx.body != null && ctx.body.size() > 0)
+                       rewriter.insertBefore(ctx.body.get(0).start, "\n");
+                       rewriter.insertAfter(ctx.stop, "\n");
+               }
+       }
+       
+       @Override
+       public void exitParForStatement(ParForStatementContext ctx) {
+               if(currentFunction != null)
+                       variables.add(ctx.iterVar.getText());
+               if(currentFunction != null && isRewritePhase) {
+                       if(ctx.body != null && ctx.body.size() > 0)
+                               rewriter.insertBefore(ctx.body.get(0).start, 
"\n");
+                       rewriter.insertAfter(ctx.stop, "\n");
+               }
+       }
+       
+
+       @Override
+       protected ConvertedDMLSyntax convertToDMLSyntax(ParserRuleContext ctx, 
String namespace, String functionName,
+                       ArrayList<ParameterExpression> paramExpression, Token 
fnName) {
+               
+               return null;
+       }
+
+       @Override
+       public void 
enterAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterAddSubExpression(AddSubExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterAssignmentStatement(AssignmentStatementContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterAtomicExpression(AtomicExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterBooleanAndExpression(BooleanAndExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterBooleanNotExpression(BooleanNotExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterBooleanOrExpression(BooleanOrExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
enterBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
enterCommandlineParamExpression(CommandlineParamExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
enterCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterConstDoubleIdExpression(ConstDoubleIdExpressionContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterConstFalseExpression(ConstFalseExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterConstIntIdExpression(ConstIntIdExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterConstStringIdExpression(ConstStringIdExpressionContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterConstTrueExpression(ConstTrueExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterDataIdExpression(DataIdExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterEveryRule(ParserRuleContext arg0) {
+               
+               
+       }
+
+       @Override
+       public void 
enterExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) {
+               
+               
+       }
+       
+
+       @Override
+       public void 
enterFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
enterFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext
 ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
enterIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterIfStatement(IfStatementContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterImportStatement(ImportStatementContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
enterIterablePredicateColonExpression(IterablePredicateColonExpressionContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
enterIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterMatrixMulExpression(MatrixMulExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterMl_type(Ml_typeContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterModIntDivExpression(ModIntDivExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterMultDivExpression(MultDivExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterMultiIdExpression(MultiIdExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterParameterizedExpression(ParameterizedExpressionContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterPathStatement(PathStatementContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterPowerExpression(PowerExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterProgramroot(ProgramrootContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterRelationalExpression(RelationalExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
enterStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
enterStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterTypedArgAssign(TypedArgAssignContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterTypedArgNoAssign(TypedArgNoAssignContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterUnaryExpression(UnaryExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterValueType(ValueTypeContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void enterWhileStatement(WhileStatementContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
exitAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertAfter(ctx.stop, ";\n");
+               }
+               
+       }
+
+       @Override
+       public void exitAddSubExpression(AddSubExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitAssignmentStatement(AssignmentStatementContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertAfter(ctx.stop, ";\n");
+               }
+               
+       }
+
+       @Override
+       public void exitAtomicExpression(AtomicExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitBooleanAndExpression(BooleanAndExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitBooleanNotExpression(BooleanNotExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitBooleanOrExpression(BooleanOrExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
exitBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
exitCommandlineParamExpression(CommandlineParamExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
exitCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitConstDoubleIdExpression(ConstDoubleIdExpressionContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitConstFalseExpression(ConstFalseExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitConstIntIdExpression(ConstIntIdExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitConstStringIdExpression(ConstStringIdExpressionContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitConstTrueExpression(ConstTrueExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitDataIdExpression(DataIdExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitEveryRule(ParserRuleContext arg0) {
+               
+               
+       }
+
+       @Override
+       public void 
exitExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) {
+               
+               
+       }
+
+       
+
+       @Override
+       public void 
exitFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) 
{
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertAfter(ctx.stop, ";\n");
+               }
+               
+       }
+
+       @Override
+       public void 
exitFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext
 ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertAfter(ctx.stop, ";\n");
+               }
+               
+       }
+
+       @Override
+       public void 
exitIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertAfter(ctx.stop, ";\n");
+               }
+               
+       }
+
+       @Override
+       public void exitIfStatement(IfStatementContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       if(ctx.ifBody != null && ctx.ifBody.size() > 0)
+                               rewriter.insertBefore(ctx.ifBody.get(0).start, 
"\n");
+                       if(ctx.elseBody != null && ctx.elseBody.size() > 0)
+                               
rewriter.insertBefore(ctx.elseBody.get(0).start, "\n");
+                       rewriter.insertAfter(ctx.stop, "\n");
+               }
+               
+       }
+
+       @Override
+       public void exitImportStatement(ImportStatementContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertAfter(ctx.stop, ";\n");
+               }
+               
+       }
+
+       @Override
+       public void 
exitIterablePredicateColonExpression(IterablePredicateColonExpressionContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
exitIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitMatrixMulExpression(MatrixMulExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitMl_type(Ml_typeContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitModIntDivExpression(ModIntDivExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitMultDivExpression(MultDivExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitMultiIdExpression(MultiIdExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitParameterizedExpression(ParameterizedExpressionContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitPathStatement(PathStatementContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       rewriter.insertAfter(ctx.stop, ";\n");
+               }
+               
+       }
+
+       @Override
+       public void exitPowerExpression(PowerExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitProgramroot(ProgramrootContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitRelationalExpression(RelationalExpressionContext ctx) {
+               
+               
+       }
+       
+       @Override
+       public void 
exitStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void 
exitStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext 
ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitTypedArgAssign(TypedArgAssignContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitTypedArgNoAssign(TypedArgNoAssignContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitUnaryExpression(UnaryExpressionContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitValueType(ValueTypeContext ctx) {
+               
+               
+       }
+
+       @Override
+       public void exitWhileStatement(WhileStatementContext ctx) {
+               if(currentFunction != null && isRewritePhase) {
+                       if(ctx.body != null && ctx.body.size() > 0)
+                               rewriter.insertBefore(ctx.body.get(0).start, 
"\n");
+                       rewriter.insertAfter(ctx.stop, "\n");
+               }
+       }
+
+       /**
+        * Always returns {@code null}; this listener supplies no literal for false.
+        * NOTE(review): presumably never consulted while inlining - confirm that
+        * no caller dereferences the result.
+        */
+       @Override
+       public String falseStringLiteral() {
+               
+               return null;
+       }
+
+       /**
+        * Always returns {@code null}; all three arguments are ignored.
+        * NOTE(review): presumably language-specific functions are irrelevant to
+        * the inlining pass - confirm that callers tolerate a null Expression.
+        */
+       @Override
+       protected Expression handleLanguageSpecificFunction(ParserRuleContext 
ctx, String functionName,
+                       ArrayList<ParameterExpression> paramExpressions) {
+               
+               return null;
+       }
+
+       /**
+        * Always returns {@code null}; this listener supplies no namespace operator.
+        * NOTE(review): presumably never consulted while inlining - confirm that
+        * no caller dereferences the result.
+        */
+       @Override
+       public String namespaceResolutionOp() {
+               
+               return null;
+       }
+
+       /**
+        * Always returns {@code null}; this listener supplies no literal for true.
+        * NOTE(review): presumably never consulted while inlining - confirm that
+        * no caller dereferences the result.
+        */
+       @Override
+       public String trueStringLiteral() {
+               
+               return null;
+       }
+
+       @Override
+       public void visitErrorNode(ErrorNode arg0) {
+               
+               
+       }
+
+       @Override
+       public void visitTerminal(TerminalNode arg0) {
+               
+               
+       }
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java 
b/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java
new file mode 100644
index 0000000..d3b0d11
--- /dev/null
+++ b/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.parser.dml;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Random;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+
+/**
+ * Simple container class used to hold a function that is to be inlined.
+ * It contains the function name, the function body, the variable names that
+ * are substituted in the body, and the formal input and return arguments.
+ * The user invokes {@link #getInlinedDML(ArrayList, ArrayList)} to get the
+ * inlined DML code for one call site.
+ */
+public class InlineableMethods {
+       // Variable names to substitute in the body, ordered longest first.
+       ArrayList<String> _variables;
+       final String _body;
+       final String _fnName;
+       final ArrayList<String> _inputArgs;
+       final ArrayList<String> _retVariables;
+       // Incremented once per generated inlining so that the local variables
+       // of every inlined call receive distinct names.
+       static int CALLER_ID = 1;
+       
+       /**
+        * @param fnName       name of the function to inline
+        * @param body         text of the function body in which variables are substituted
+        * @param variables    names that are substituted in the body; each is expected
+        *                     to start with InlineHelper.ARG_PREFIX
+        * @param inputArgs    formal input argument names, in declaration order
+        * @param retVariables formal return variable names, in declaration order
+        */
+       public InlineableMethods(String fnName, String body, HashSet<String> variables, ArrayList<String> inputArgs, ArrayList<String> retVariables) {
+               _fnName = fnName;
+               _body = body;
+               _variables = new ArrayList<String>(variables);
+               // Longest names first: presumably so that a name which is a prefix of
+               // a longer name is not substituted inside that longer name.
+               _variables.sort(Comparator.comparing(String::length).reversed());
+               _inputArgs = inputArgs;
+               _retVariables = retVariables;
+       }
+       
+       /**
+        * @return the internal (not copied) list of variable names, longest first
+        */
+       public ArrayList<String> getLocalVariables() {
+               return _variables;
+       }
+       
+       /**
+        * Substitutes every known variable in the body. Variables that appear as
+        * keys in actualArguments are replaced by the mapped text; all others are
+        * renamed to a name built from LOCAL_ARG_PREFIX, the function name, a
+        * per-call id and the variable name without InlineHelper.ARG_PREFIX.
+        */
+       private String _getInlinedDML(HashMap<String, String> actualArguments) {
+               String ret = _body;
+               int callerID = CALLER_ID++;
+               for(String var : _variables) {
+                       // String.replace substitutes literally. replaceAll would treat the
+                       // name as a regex and '$' or '\' in the replacement as group
+                       // references, which fails for arguments such as "$debug".
+                       if(actualArguments.containsKey(var)) {
+                               ret = ret.replace(var, actualArguments.get(var));
+                       }
+                       else {
+                               // internal argument
+                               String originalVarName = var.substring(InlineHelper.ARG_PREFIX.length());
+                               ret = ret.replace(var, LOCAL_ARG_PREFIX + _fnName + "_" + callerID + "_" + originalVarName);
+                       }
+               }
+               return ret;
+       }
+       
+       /**
+        * Returns the function body with the formal arguments replaced by the
+        * given actual arguments and all remaining variables renamed.
+        *
+        * @param actualInputArgs    actual input arguments, positionally matching the formal inputs
+        * @param actualRetVariables actual return variables, positionally matching the formal returns
+        * @return the inlined DML code
+        * @throws DMLRuntimeException if either list differs in size from its formal counterpart
+        */
+       public String getInlinedDML(ArrayList<String> actualInputArgs, ArrayList<String> actualRetVariables) {
+               HashMap<String, String> actualArguments = new HashMap<>();
+               if(actualInputArgs.size() != _inputArgs.size()) {
+                       throw new DMLRuntimeException("Incorrect number of input arguments for the function " + _fnName + ": expected " 
+                       + _inputArgs.size() + " (" + String.join(", ", _inputArgs) + ") but found " + actualInputArgs.size() 
+                       + " (" + String.join(", ", actualInputArgs) + ")");
+               }
+               if(actualRetVariables.size() != _retVariables.size()) {
+                       throw new DMLRuntimeException("Incorrect number of return variables for the function " + _fnName + ": expected " 
+                       + _retVariables.size() + " (" + String.join(", ", _retVariables) + ") but found " + actualRetVariables.size()
+                       + " (" + String.join(", ", actualRetVariables) + ")");
+               }
+               for(int i = 0; i < _inputArgs.size(); i++) {
+                       actualArguments.put(_inputArgs.get(i), actualInputArgs.get(i));
+               }
+               // Return variables are added second, so a name that is both an input
+               // and a return variable maps to the actual return variable.
+               for(int i = 0; i < _retVariables.size(); i++) {
+                       actualArguments.put(_retVariables.get(i), actualRetVariables.get(i));
+               }
+               return _getInlinedDML(actualArguments);
+       }
+       
+       // Randomised prefix for renamed local variables, fixed once per JVM.
+       static final String LOCAL_ARG_PREFIX;
+       static {
+               Random rand = new Random();
+               // Mask the sign bit instead of Math.abs: Math.abs(Long.MIN_VALUE) is
+               // still negative and would put a '-' into the generated identifier.
+               LOCAL_ARG_PREFIX = "LOCAL_" + (rand.nextLong() & Long.MAX_VALUE) + "_" + (rand.nextLong() & Long.MAX_VALUE);
+       }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index 03bc3b3..15dd23e 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -307,7 +307,14 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
        }
        
        public MatrixCharacteristics getMatrixCharacteristics() {
-               return _metaData.getMatrixCharacteristics();
+               MatrixCharacteristics mc = _metaData.getMatrixCharacteristics();
+               if(mc.getRowsPerBlock() == -1) {
+                       mc.setRowsPerBlock(OptimizerUtils.DEFAULT_BLOCKSIZE);
+               }
+               if(mc.getColsPerBlock() == -1) {
+                       mc.setColsPerBlock(OptimizerUtils.DEFAULT_BLOCKSIZE);
+               }
+               return mc;
        }
 
        public abstract void refreshMetaData();

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index 6772b4a..a08d4fd 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -298,7 +298,7 @@ public class GPUMemoryManager {
                                evictOrClear(sizeBasedUnlockedGPUObjects.get(), 
opcode);
                                A = cudaMallocNoWarn(tmpA, size, null);
                                if(A == null)
-                                       LOG.warn("cudaMalloc failed after 
clearing/evicting based on size.");
+                                       LOG.debug("cudaMalloc failed after 
clearing/evicting based on size.");
                                if(ConfigurationManager.isStatistics()) {
                                        long totalTime = System.nanoTime() - t0;
                                        
GPUStatistics.cudaEvictTime.add(totalTime);

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py 
b/src/main/python/systemml/mllearn/estimators.py
index fbcd3e2..8a100b4 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -924,7 +924,7 @@ class Caffe2DML(BaseSystemMLClassifier):
             self.estimator.setWeightsToIgnore(ignore_weights)
 
     def set(self, debug=None, train_algo=None, test_algo=None, 
parallel_batches=None,
-            output_activations=None, perform_one_hot_encoding=None, 
parfor_parameters=None):
+            output_activations=None, perform_one_hot_encoding=None, 
parfor_parameters=None, inline_nn_library=None):
         """
         Set input to Caffe2DML
 
@@ -937,9 +937,12 @@ class Caffe2DML(BaseSystemMLClassifier):
         output_activations: (developer flag) directory to output activations 
of each layer as csv while prediction. To be used only in batch mode (default: 
None)
         perform_one_hot_encoding: should perform one-hot encoding in DML using 
table function (default: False)
         parfor_parameters: dictionary for parfor parameters when using 
allreduce-style algorithms (default: "")
+        inline_nn_library: whether to inline the NN library when generating 
DML using Caffe2DML (default: False)
         """
         if debug is not None:
             self.estimator.setInput("$debug", str(debug).upper())
+        if inline_nn_library is not None:
+            self.estimator.setInput("$inline_nn_library", 
str(inline_nn_library).upper())
         if train_algo is not None:
             self.estimator.setInput("$train_algo", str(train_algo).lower())
         if test_algo is not None:

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala 
b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index 26e554f..8ddb1fe 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -119,8 +119,9 @@ object Caffe2DML {
   val LOG = LogFactory.getLog(classOf[Caffe2DML].getName())
   // ------------------------------------------------------------------------
   val USE_PLUS_EQ = true
-  def layerDir = "nn/layers/"
-  def optimDir = "nn/optim/"
+  def nnDir = "nn/"
+  def layerDir = nnDir + "layers/"
+  def optimDir = nnDir + "optim/"
 
   // Naming conventions:
   val X    = "X"; val y        = "y"; val batchSize = "BATCH_SIZE"; val 
numImages = "num_images"; val numValidationImages = "num_validation"
@@ -159,6 +160,7 @@ object Caffe2DML {
   val BATCH_ALGORITHM = "batch"
   val ALLREDUCE_ALGORITHM = "allreduce"
   val ALLREDUCE_PARALLEL_BATCHES_ALGORITHM = "allreduce_parallel_batches"
+  var INLINE_NN_LIBRARY = false
 }
 
 class Caffe2DML(val sc: SparkContext,
@@ -312,6 +314,7 @@ class Caffe2DML(val sc: SparkContext,
 
     // Flags passed by user
     val DEBUG_TRAINING = if (inputs.containsKey("$debug")) 
inputs.get("$debug").toLowerCase.toBoolean else false
+    Caffe2DML.INLINE_NN_LIBRARY = if 
(inputs.containsKey("$inline_nn_library")) 
inputs.get("$inline_nn_library").toLowerCase.toBoolean else false
     assign(tabDMLScript, "debug", if (DEBUG_TRAINING) "TRUE" else "FALSE")
     setDebugFlags(DEBUG_TRAINING)
 
@@ -721,6 +724,7 @@ class Caffe2DMLModel(val numClasses: String, val sc: 
SparkContext, val solver: C
     reset // Reset the state of DML generator for training script.
 
     val DEBUG_PREDICTION = if (estimator.inputs.containsKey("$debug")) 
estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+    Caffe2DML.INLINE_NN_LIBRARY = if 
(estimator.inputs.containsKey("$inline_nn_library")) 
estimator.inputs.get("$inline_nn_library").toLowerCase.toBoolean else false
     assign(tabDMLScript, "debug", if (DEBUG_PREDICTION) "TRUE" else "FALSE")
     estimator.setDebugFlags(DEBUG_PREDICTION)
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala 
b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index d664f6e..b290983 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -29,6 +29,7 @@ import caffe.Caffe.EltwiseParameter.EltwiseOp
 import org.apache.sysml.runtime.DMLRuntimeException;
 import java.util.ArrayList
 import caffe.Caffe.PoolingParameter.PoolMethod
+import scala.collection.JavaConverters._
 
 trait CaffeLayer extends BaseDMLGenerator {
   // -------------------------------------------------
@@ -125,7 +126,7 @@ trait CaffeLayer extends BaseDMLGenerator {
   // The layers that have a corresponding dml script call this method.
   // Assumption: the first variable of resultVariables is always dX
   def invokeBackward(dmlScript: StringBuilder, outSuffix: String, 
resultVariables: List[String], arguments: String*): Unit = {
-    invoke(dmlScript, sourceFileName + "::", resultVariables.map(_ + 
outSuffix), "backward", arguments.toList, false)
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, sourceFileName + "::", 
resultVariables.map(_ + outSuffix), "backward", arguments.toList, false)
     val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => 
net.getCaffeLayer(l).id)
     dmlScript.append("; ")
     bottomLayerIDs.map(bottomLayerID => dmlScript.append(dX(bottomLayerID) + 
outSuffix + " = " + resultVariables(0) + outSuffix + "; "))
@@ -140,6 +141,13 @@ trait CaffeLayer extends BaseDMLGenerator {
     dmlScript.append("\n")
   }
   // 
--------------------------------------------------------------------------------------
+  
+  // Convenience overloads that forward to Utils.invoke, always passing
+  // Caffe2DML.layerDir as the first argument.
+  // List-based arguments; appendNewLine is fixed to true.
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: 
List[String], functionName: String, arguments: List[String]): Unit =
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, 
functionName, arguments, true)
+  // Vararg arguments with an explicit appendNewLine flag.
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: 
List[String], functionName: String, appendNewLine: Boolean, arguments: 
String*): Unit =
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, 
functionName, arguments.toList, appendNewLine)
+  // Vararg arguments; appendNewLine is fixed to true.
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: 
List[String], functionName: String, arguments: String*): Unit =
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, 
functionName, arguments.toList, true)
 }
 
 trait IsLossLayer extends CaffeLayer {
@@ -1603,4 +1611,4 @@ class DeConvolution(val param: LayerParameter, val id: 
Int, val net: CaffeNetwor
     if (convParam.hasPadW) convParam.getPadW.toString
     else if (convParam.getPadCount > 0) convParam.getPad(0).toString
     else "0"
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala 
b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
index 8559c60..da963b9 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
@@ -67,10 +67,12 @@ trait CaffeSolver {
       val hasDecayMult = layer.param.getParamList != null && 
layer.param.getParamList.size >= 1 && 
layer.param.getParamList.get(0).hasDecayMult
       val newLambda = if(hasDecayMult) 
layer.param.getParamList.get(0).getDecayMult * lambda else lambda
       
-      dmlScript.append("\t").append(layer.dWeight + "_reg = " + 
regularizationSource + "::backward(" + layer.weight + ", " + newLambda + ")\n")
+      Utils.invoke(Caffe2DML.layerDir, dmlScript, regularizationSource + "::", 
List[String](layer.dWeight + "_reg"), "backward", 
+          List[String](layer.weight, "" + newLambda), true)
       dmlScript.append("\t").append(layer.dWeight + " = " + layer.dWeight + " 
+ " + layer.dWeight + "_reg\n")
       if(layer.shouldUpdateExtraWeight) {
-        dmlScript.append("\t").append(layer.dExtraWeight + "_reg = " + 
regularizationSource + "::backward(" + layer.extraWeight + ", " + newLambda + 
")\n")
+        Utils.invoke(Caffe2DML.layerDir, dmlScript, regularizationSource + 
"::", List[String](layer.dExtraWeight + "_reg"), "backward", 
+          List[String](layer.extraWeight, "" + newLambda), true)
         dmlScript.append("\t").append(layer.dExtraWeight + " = " + 
layer.dExtraWeight + " + " + layer.dExtraWeight + "_reg\n")
       }
     }
@@ -339,32 +341,19 @@ class Nesterov(regularizationType:String = "L2", lambda: 
Double = 5e-04, momentu
    *      input v.
    */
   def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
-    val fn            = if (Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else 
"sgd_nesterov::update"
-    val lastParameter = if (Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
+    
     if (!Caffe2DML.USE_NESTEROV_UDF) {
       regularization_update(regularizationType, lambda, dmlScript, layer)
     }
     if (layer.shouldUpdateWeight)
-      dmlScript
-        .append("\t")
-        .append(
-          "[" + commaSep(layer.weight, layer.weight + "_v") + "] " +
-          "= " + fn + "(" + commaSep(layer.weight, layer.dWeight, 
getWeightLr(layer), momentum.toString, layer.weight + "_v") + lastParameter + 
")\n"
-        )
+      Utils.invoke(Caffe2DML.optimDir, dmlScript, "sgd_nesterov::", 
List[String](layer.weight, layer.weight + "_v"), "update", 
+          List[String](layer.weight, layer.dWeight, getWeightLr(layer), 
momentum.toString, layer.weight + "_v"), true)
     if (layer.shouldUpdateExtraWeight)
-      dmlScript
-        .append("\t")
-        .append(
-          "[" + commaSep(layer.extraWeight, layer.extraWeight + "_v") + "] " +
-          "= " + fn + "(" + commaSep(layer.extraWeight, layer.dExtraWeight, 
getWeightLr(layer), momentum.toString, layer.extraWeight + "_v") + 
lastParameter + ")\n"
-        )
+      Utils.invoke(Caffe2DML.optimDir, dmlScript, "sgd_nesterov::", 
List[String](layer.extraWeight, layer.extraWeight + "_v"), "update", 
+          List[String](layer.extraWeight, layer.dExtraWeight, 
getWeightLr(layer), momentum.toString, layer.extraWeight + "_v"), true)
     if (layer.shouldUpdateBias)
-      dmlScript
-        .append("\t")
-        .append(
-          "[" + commaSep(layer.bias, layer.bias + "_v") + "] " +
-          "= " + fn + "(" + commaSep(layer.bias, layer.dBias, 
getBiasLr(layer), momentum.toString, layer.bias + "_v") + lastParameter + ")\n"
-        )
+      Utils.invoke(Caffe2DML.optimDir, dmlScript, "sgd_nesterov::", 
List[String](layer.bias, layer.bias + "_v"), "update", 
+          List[String](layer.bias, layer.dBias, getBiasLr(layer), 
momentum.toString, layer.bias + "_v"), true)
   }
   def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
     if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_v = 
sgd_nesterov::init(" + layer.weight + ")\n")
@@ -372,4 +361,4 @@ class Nesterov(regularizationType:String = "L2", lambda: 
Double = 5e-04, momentu
     if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_v = 
sgd_nesterov::init(" + layer.bias + ")\n")
   }
   def sourceFileName: String = "sgd_nesterov"
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala 
b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 60396f1..59c75ad 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -75,39 +75,6 @@ trait BaseDMLGenerator {
     sum(dmlScript, rhsVars)
     dmlScript.append("\n")
   }
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: 
List[String], functionName: String, arguments: List[String]): Unit =
-    invoke(dmlScript, namespace1, returnVariables, functionName, arguments, 
true)
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: 
List[String], functionName: String, arguments: List[String], appendNewLine: 
Boolean): Unit = {
-    if (returnVariables.length == 0) throw new 
DMLRuntimeException("User-defined functions should have atleast one return 
value")
-    if (returnVariables.length > 1) dmlScript.append("[")
-    dmlScript.append(returnVariables(0))
-    if (returnVariables.length > 1) {
-      for (i <- 1 until returnVariables.length) {
-        dmlScript.append(",").append(returnVariables(i))
-      }
-      dmlScript.append("]")
-    }
-    dmlScript.append(" = ")
-    dmlScript.append(namespace1)
-    dmlScript.append(functionName)
-    dmlScript.append("(")
-    if (arguments != null) {
-      if (arguments.length != 0)
-        dmlScript.append(arguments(0))
-      if (arguments.length > 1) {
-        for (i <- 1 until arguments.length) {
-          dmlScript.append(",").append(arguments(i))
-        }
-      }
-    }
-    dmlScript.append(")")
-    if (appendNewLine)
-      dmlScript.append("\n")
-  }
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: 
List[String], functionName: String, appendNewLine: Boolean, arguments: 
String*): Unit =
-    invoke(dmlScript, namespace1, returnVariables, functionName, 
arguments.toList, appendNewLine)
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: 
List[String], functionName: String, arguments: String*): Unit =
-    invoke(dmlScript, namespace1, returnVariables, functionName, 
arguments.toList, true)
   def rightIndexing(dmlScript: StringBuilder, lhsVar:String, rhsVar: String, 
rl: String, ru: String, cl: String=null, cu: String=null): StringBuilder = {
     dmlScript.append(lhsVar).append(" = ").append(rhsVar).append("[")
     if (rl != null && ru != null) dmlScript.append(rl).append(":").append(ru)
@@ -279,6 +246,7 @@ trait DMLGenerator extends SourceDMLGenerator with 
NextBatchGenerator {
     // Append source statements for layers as well as solver
     source(net, solver, if (isTraining) Array[String]("l1_reg") else null)
     source(net, solver, if (isTraining) Array[String]("l2_reg") else null)
+    source(dmlScript, numTabs, "util", Caffe2DML.nnDir)
 
     if (isTraining) {
       // Append external built-in function headers:
@@ -346,4 +314,4 @@ trait DMLGenerator extends SourceDMLGenerator with 
NextBatchGenerator {
 
   def updateMeanVarianceForBatchNorm(net: CaffeNetwork, value: Boolean): Unit =
     
net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var
 = value)
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Utils.scala 
b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
index 63aaf91..da53f65 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
@@ -39,6 +39,11 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.api.mlcontext.MLContext
 import org.apache.spark.SparkContext
 import org.apache.spark.api.java.JavaSparkContext
+import org.apache.sysml.parser.ParserWrapper
+import org.apache.sysml.parser.dml.DMLParserWrapper
+import org.apache.sysml.parser.dml.InlineableMethods
+import java.util.ArrayList
+import scala.collection.JavaConverters._
 
 object Utils {
   // 
---------------------------------------------------------------------------------------------
@@ -64,6 +69,57 @@ object Utils {
       line = bufReader.readLine()
     }
   }
+  
+  def readDMLScript(fileName:String):String = 
ParserWrapper.readDMLScript(fileName, Caffe2DML.LOG)
+  val inlineableMethods = new java.util.HashMap[String, 
java.util.HashMap[String, InlineableMethods]]()
+  def getInlineableMethod(sourceFilePath:String, namespace:String, 
fnName:String):InlineableMethods = {
+    if(inlineableMethods.contains(namespace))
+      return inlineableMethods.get(namespace).get(fnName)
+    else {
+      val ret = new DMLParserWrapper().getInlineableMethods(sourceFilePath, 
null, namespace, null)
+      inlineableMethods.put(namespace, ret)
+      return ret.get(fnName)
+    }
+  }
+  
+  def invoke(dir:String, dmlScript: StringBuilder, namespace1: String, 
returnVariables: List[String], functionName: String, arguments: List[String], 
appendNewLine: Boolean): Unit = {
+    if(Caffe2DML.INLINE_NN_LIBRARY) {
+      // Caffe2DML.layerDir
+      // For now, do not inline recursively
+      val sourceFileName = if(namespace1.endsWith("::")) 
namespace1.substring(0, namespace1.length() - 2) else namespace1
+      val method = getInlineableMethod(dir + sourceFileName + ".dml", 
namespace1, functionName)
+      val generatedDML = method.getInlinedDML(new 
ArrayList[String](arguments.asJava), new 
ArrayList[String](returnVariables.asJava))
+      dmlScript.append(generatedDML)
+      dmlScript.append("\n")
+      //System.out.println(generatedDML)
+      return
+    }
+    if (returnVariables.length == 0) throw new 
DMLRuntimeException("User-defined functions should have atleast one return 
value")
+    if (returnVariables.length > 1) dmlScript.append("[")
+    dmlScript.append(returnVariables(0))
+    if (returnVariables.length > 1) {
+      for (i <- 1 until returnVariables.length) {
+        dmlScript.append(",").append(returnVariables(i))
+      }
+      dmlScript.append("]")
+    }
+    dmlScript.append(" = ")
+    dmlScript.append(namespace1)
+    dmlScript.append(functionName)
+    dmlScript.append("(")
+    if (arguments != null) {
+      if (arguments.length != 0)
+        dmlScript.append(arguments(0))
+      if (arguments.length > 1) {
+        for (i <- 1 until arguments.length) {
+          dmlScript.append(",").append(arguments(i))
+        }
+      }
+    }
+    dmlScript.append(")")
+    if (appendNewLine)
+      dmlScript.append("\n")
+  }
 
   // 
---------------------------------------------------------------------------------------------
   def parseSolver(solverFilePath: String): CaffeSolver = 
parseSolver(readCaffeSolver(solverFilePath))
@@ -324,4 +380,4 @@ class Utils {
   def saveCaffeModelFile(sc: JavaSparkContext, deployFilePath: String, 
caffeModelFilePath: String, outputDirectory: String, format: String): Unit =
     Utils.saveCaffeModelFile(sc, deployFilePath, caffeModelFilePath, 
outputDirectory, format)
 
-}
+}
\ No newline at end of file

Reply via email to