[HOTFIX] Various fixes for size propagation and batch_norm2d integration

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7d936cf0
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7d936cf0
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7d936cf0

Branch: refs/heads/master
Commit: 7d936cf0c1f4061c1b611378cd2e07a23372cfc0
Parents: 153fd89
Author: Matthias Boehm <[email protected]>
Authored: Fri Jun 1 22:23:52 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Jun 1 22:23:52 2018 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/FunctionOp.java  | 15 +++------------
 .../sysml/hops/rewrite/RewriteConstantFolding.java   | 15 ++++++++++-----
 .../sysml/parser/BuiltinFunctionExpression.java      |  4 ++--
 .../java/org/apache/sysml/parser/DataExpression.java |  5 ++---
 .../org/apache/sysml/parser/ForStatementBlock.java   |  6 ++++--
 .../sysml/runtime/matrix/data/LibMatrixCUDA.java     |  3 ---
 6 files changed, 21 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java 
b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index da13cd1..428f357 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -280,10 +280,8 @@ public class FunctionOp extends Hop
        protected ExecType optFindExecType() 
        {
                checkAndSetForcedPlatform();
-               ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? 
ExecType.SPARK : ExecType.MR;
                
                if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) {
-                       
                        // check if there is sufficient memory to execute this 
function
                        if( 
getFunctionName().equalsIgnoreCase("transformencode") ) {
                                _etype = ((_etypeForced==ExecType.SPARK 
@@ -291,19 +289,12 @@ public class FunctionOp extends Hop
                                                && 
OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
                        }
                        else if( getFunctionName().equalsIgnoreCase("lstm")) {
-                               if(DMLScript.USE_ACCELERATOR)
-                                       _etype = ExecType.GPU;
-                               else
+                               if(!DMLScript.USE_ACCELERATOR)
                                        throw new RuntimeException("The 
function " + getFunctionName() + " is only supported on GPU.");
+                               _etype = ExecType.GPU;
                        }
                        else if( 
getFunctionName().equalsIgnoreCase("batch_norm2d") || 
getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
-                               if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
-                                       _etype = findExecTypeByMemEstimate();
-                               }
-                               else {
-                                       _etype = ExecType.CP;
-                               }
-                               _etype = _etype == REMOTE ?  ExecType.CP : 
_etype; // batch_norm2d and batch_norm2d_backward are  not supported on Spark
+                               _etype = DMLScript.USE_ACCELERATOR ? 
ExecType.GPU : ExecType.CP;
                        }
                        else {
                                // Since the memory estimate is only 
conservative, do not throw

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
index 304fbc3..a153fb1 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
@@ -104,16 +104,21 @@ public class RewriteConstantFolding extends HopRewriteRule
                        catch(Exception ex) {
                                LOG.error("Failed to execute constant folding 
instructions. No abort.", ex);
                        }
-                       
+               }
+               //fold nrow as precondition for further constant folding 
+               else if( HopRewriteUtils.isUnary(root, OpOp1.NROW) && 
root.getInput().get(0).rowsKnown() ) {
+                       literal = new 
LiteralOp(root.getInput().get(0).getDim1());
+               }
+               //fold ncol as precondition for further constant folding 
+               else if( HopRewriteUtils.isUnary(root, OpOp1.NCOL) && 
root.getInput().get(0).colsKnown() ) {
+                       literal = new 
LiteralOp(root.getInput().get(0).getDim2());
                }
                //fold conjunctive predicate if at least one input is literal 
'false'
-               else if( isApplicableFalseConjunctivePredicate(root) )
-               {
+               else if( isApplicableFalseConjunctivePredicate(root) ) {
                        literal = new LiteralOp(false);
                }
                //fold disjunctive predicate if at least one input is literal 
'true'
-               else if( isApplicableTrueDisjunctivePredicate(root) )
-               {
+               else if( isApplicableTrueDisjunctivePredicate(root) ) {
                        literal = new LiteralOp(true);
                }
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 42b8329..eb9e2c3 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -99,7 +99,7 @@ public class BuiltinFunctionExpression extends DataIdentifier
        }
        
        public Expression getSixthExpr() {
-               return (_args.length >= 5 ? _args[4] : null);
+               return (_args.length >= 6 ? _args[5] : null);
        }
 
        public Expression[] getAllExpr(){
@@ -107,7 +107,7 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
        }
        
        public Expression getExpr(int i) {
-               return (_args.length > i ? _args[i] : null);    
+               return (_args.length > i ? _args[i] : null);
        }
        
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/parser/DataExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DataExpression.java 
b/src/main/java/org/apache/sysml/parser/DataExpression.java
index 7218023..93379fa 100644
--- a/src/main/java/org/apache/sysml/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysml/parser/DataExpression.java
@@ -535,10 +535,9 @@ public class DataExpression extends DataIdentifier
                }       
        
                //general data expression constant propagation
-               if( !conditional ) {
+               if( !conditional )
                        performConstantPropagationRand( currConstVars );
-                       performConstantPropagationReadWrite( currConstVars );
-               }
+               performConstantPropagationReadWrite( currConstVars );
                
                // check if data parameter of matrix is scalar or matrix -- if 
scalar, use Rand instead
                Expression dataParam1 = getVarParam(RAND_DATA);         

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/parser/ForStatementBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/ForStatementBlock.java 
b/src/main/java/org/apache/sysml/parser/ForStatementBlock.java
index dd6f7a4..62f59dd 100644
--- a/src/main/java/org/apache/sysml/parser/ForStatementBlock.java
+++ b/src/main/java/org/apache/sysml/parser/ForStatementBlock.java
@@ -89,7 +89,8 @@ public class ForStatementBlock extends StatementBlock
                //validate body
                _dmlProg = dmlProg;
                for(StatementBlock sb : body) {
-                       ids = sb.validate(dmlProg, ids, constVars, true);
+                       ids = sb.validate(dmlProg, ids, constVars,
+                               !(this instanceof ParForStatementBlock));
                        constVars = sb.getConstOut();
                }
                
@@ -165,7 +166,8 @@ public class ForStatementBlock extends StatementBlock
                        //validate body
                        _dmlProg = dmlProg;
                        for(StatementBlock sb : body) {
-                               ids = sb.validate(dmlProg, ids, constVars, 
true);
+                               ids = sb.validate(dmlProg, ids, constVars,
+                                       !(this instanceof 
ParForStatementBlock));
                                constVars = sb.getConstOut();
                        }
                        if (!body.isEmpty()){

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 8b5043f..e4fdc96 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -21,7 +21,6 @@ package org.apache.sysml.runtime.matrix.data;
 
 import static jcuda.jcublas.cublasOperation.CUBLAS_OP_N;
 import static jcuda.jcublas.cublasOperation.CUBLAS_OP_T;
-import static jcuda.runtime.JCuda.cudaMalloc;
 import static jcuda.runtime.JCuda.cudaMemcpy;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
@@ -89,10 +88,8 @@ import jcuda.jcublas.cublasHandle;
 import jcuda.jcublas.cublasOperation;
 import jcuda.jcublas.cublasSideMode;
 import jcuda.jcusparse.cusparseAction;
-import jcuda.jcusparse.cusparseDirection;
 import jcuda.jcusparse.cusparseHandle;
 import jcuda.jcusparse.cusparseIndexBase;
-import jcuda.jcusparse.cusparseMatDescr;
 
 /**
  * All CUDA kernels and library calls are redirected through this class

Reply via email to