[HOTFIX] Various fixes for size propagation and batch_norm2d integration Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7d936cf0 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7d936cf0 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7d936cf0
Branch: refs/heads/master Commit: 7d936cf0c1f4061c1b611378cd2e07a23372cfc0 Parents: 153fd89 Author: Matthias Boehm <[email protected]> Authored: Fri Jun 1 22:23:52 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Jun 1 22:23:52 2018 -0700 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/FunctionOp.java | 15 +++------------ .../sysml/hops/rewrite/RewriteConstantFolding.java | 15 ++++++++++----- .../sysml/parser/BuiltinFunctionExpression.java | 4 ++-- .../java/org/apache/sysml/parser/DataExpression.java | 5 ++--- .../org/apache/sysml/parser/ForStatementBlock.java | 6 ++++-- .../sysml/runtime/matrix/data/LibMatrixCUDA.java | 3 --- 6 files changed, 21 insertions(+), 27 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/hops/FunctionOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java index da13cd1..428f357 100644 --- a/src/main/java/org/apache/sysml/hops/FunctionOp.java +++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java @@ -280,10 +280,8 @@ public class FunctionOp extends Hop protected ExecType optFindExecType() { checkAndSetForcedPlatform(); - ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR; if ( getFunctionType() == FunctionType.MULTIRETURN_BUILTIN ) { - // check if there is sufficient memory to execute this function if( getFunctionName().equalsIgnoreCase("transformencode") ) { _etype = ((_etypeForced==ExecType.SPARK @@ -291,19 +289,12 @@ public class FunctionOp extends Hop && OptimizerUtils.isSparkExecutionMode())) ? 
ExecType.SPARK : ExecType.CP); } else if( getFunctionName().equalsIgnoreCase("lstm")) { - if(DMLScript.USE_ACCELERATOR) - _etype = ExecType.GPU; - else + if(!DMLScript.USE_ACCELERATOR) throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU."); + _etype = ExecType.GPU; } else if( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) { - if ( OptimizerUtils.isMemoryBasedOptLevel() ) { - _etype = findExecTypeByMemEstimate(); - } - else { - _etype = ExecType.CP; - } - _etype = _etype == REMOTE ? ExecType.CP : _etype; // batch_norm2d and batch_norm2d_backward are not supported on Spark + _etype = DMLScript.USE_ACCELERATOR ? ExecType.GPU : ExecType.CP; } else { // Since the memory estimate is only conservative, do not throw http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java index 304fbc3..a153fb1 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java @@ -104,16 +104,21 @@ public class RewriteConstantFolding extends HopRewriteRule catch(Exception ex) { LOG.error("Failed to execute constant folding instructions. 
No abort.", ex); } - + } + //fold nrow as precondition for further constant folding + else if( HopRewriteUtils.isUnary(root, OpOp1.NROW) && root.getInput().get(0).rowsKnown() ) { + literal = new LiteralOp(root.getInput().get(0).getDim1()); + } + //fold ncol as precondition for further constant folding + else if( HopRewriteUtils.isUnary(root, OpOp1.NCOL) && root.getInput().get(0).colsKnown() ) { + literal = new LiteralOp(root.getInput().get(0).getDim2()); } //fold conjunctive predicate if at least one input is literal 'false' - else if( isApplicableFalseConjunctivePredicate(root) ) - { + else if( isApplicableFalseConjunctivePredicate(root) ) { literal = new LiteralOp(false); } //fold disjunctive predicate if at least one input is literal 'true' - else if( isApplicableTrueDisjunctivePredicate(root) ) - { + else if( isApplicableTrueDisjunctivePredicate(root) ) { literal = new LiteralOp(true); } http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index 42b8329..eb9e2c3 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -99,7 +99,7 @@ public class BuiltinFunctionExpression extends DataIdentifier } public Expression getSixthExpr() { - return (_args.length >= 5 ? _args[4] : null); + return (_args.length >= 6 ? _args[5] : null); } public Expression[] getAllExpr(){ @@ -107,7 +107,7 @@ public class BuiltinFunctionExpression extends DataIdentifier } public Expression getExpr(int i) { - return (_args.length > i ? _args[i] : null); + return (_args.length > i ? 
_args[i] : null); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/parser/DataExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DataExpression.java b/src/main/java/org/apache/sysml/parser/DataExpression.java index 7218023..93379fa 100644 --- a/src/main/java/org/apache/sysml/parser/DataExpression.java +++ b/src/main/java/org/apache/sysml/parser/DataExpression.java @@ -535,10 +535,9 @@ public class DataExpression extends DataIdentifier } //general data expression constant propagation - if( !conditional ) { + if( !conditional ) performConstantPropagationRand( currConstVars ); - performConstantPropagationReadWrite( currConstVars ); - } + performConstantPropagationReadWrite( currConstVars ); // check if data parameter of matrix is scalar or matrix -- if scalar, use Rand instead Expression dataParam1 = getVarParam(RAND_DATA); http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/parser/ForStatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ForStatementBlock.java b/src/main/java/org/apache/sysml/parser/ForStatementBlock.java index dd6f7a4..62f59dd 100644 --- a/src/main/java/org/apache/sysml/parser/ForStatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/ForStatementBlock.java @@ -89,7 +89,8 @@ public class ForStatementBlock extends StatementBlock //validate body _dmlProg = dmlProg; for(StatementBlock sb : body) { - ids = sb.validate(dmlProg, ids, constVars, true); + ids = sb.validate(dmlProg, ids, constVars, + !(this instanceof ParForStatementBlock)); constVars = sb.getConstOut(); } @@ -165,7 +166,8 @@ public class ForStatementBlock extends StatementBlock //validate body _dmlProg = dmlProg; for(StatementBlock sb : body) { - ids = sb.validate(dmlProg, ids, constVars, true); + ids = 
sb.validate(dmlProg, ids, constVars, + !(this instanceof ParForStatementBlock)); constVars = sb.getConstOut(); } if (!body.isEmpty()){ http://git-wip-us.apache.org/repos/asf/systemml/blob/7d936cf0/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java index 8b5043f..e4fdc96 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java @@ -21,7 +21,6 @@ package org.apache.sysml.runtime.matrix.data; import static jcuda.jcublas.cublasOperation.CUBLAS_OP_N; import static jcuda.jcublas.cublasOperation.CUBLAS_OP_T; -import static jcuda.runtime.JCuda.cudaMalloc; import static jcuda.runtime.JCuda.cudaMemcpy; import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice; import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost; @@ -89,10 +88,8 @@ import jcuda.jcublas.cublasHandle; import jcuda.jcublas.cublasOperation; import jcuda.jcublas.cublasSideMode; import jcuda.jcusparse.cusparseAction; -import jcuda.jcusparse.cusparseDirection; import jcuda.jcusparse.cusparseHandle; import jcuda.jcusparse.cusparseIndexBase; -import jcuda.jcusparse.cusparseMatDescr; /** * All CUDA kernels and library calls are redirected through this class
