Repository: systemml
Updated Branches:
  refs/heads/master f702c03be -> 0529350a3


[SYSTEMML-540] Added an assert builtin function

- The assert function halts the execution of the DML program if its boolean
argument does not evaluate to TRUE.
- Like stop, assert is not supported inside a parfor.
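- In debug mode, Caffe2DML inserts an assert() dimension check into the forward
pass of the affine (InnerProduct) layer, unless the generated script uses a
parfor-based (allreduce) training or test algorithm.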


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0529350a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0529350a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0529350a

Branch: refs/heads/master
Commit: 0529350a3dfa665bed040da255388f513b5868d2
Parents: f702c03
Author: Niketan Pansare <[email protected]>
Authored: Fri Jan 19 14:46:26 2018 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Fri Jan 19 14:46:26 2018 -0800

----------------------------------------------------------------------
 docs/dml-language-reference.md                  |  1 +
 src/main/java/org/apache/sysml/hops/Hop.java    |  4 +-
 .../java/org/apache/sysml/hops/UnaryOp.java     |  2 +-
 .../codegen/opt/PlanSelectionFuseCostBased.java |  1 +
 .../opt/PlanSelectionFuseCostBasedV2.java       |  1 +
 .../hops/rewrite/RewriteConstantFolding.java    |  3 +-
 .../RewriteRemoveDanglingParentReferences.java  |  1 +
 .../java/org/apache/sysml/lops/UnaryCP.java     |  5 +-
 .../org/apache/sysml/parser/DMLTranslator.java  |  7 ++
 .../sysml/parser/ParForStatementBlock.java      |  4 +
 .../org/apache/sysml/parser/PrintStatement.java |  7 +-
 .../org/apache/sysml/parser/StatementBlock.java |  5 +-
 .../sysml/parser/dml/DmlSyntacticValidator.java |  1 +
 .../instructions/CPInstructionParser.java       |  1 +
 .../cp/UnaryScalarCPInstruction.java            |  7 ++
 .../org/apache/sysml/api/dl/Caffe2DML.scala     |  2 +
 .../org/apache/sysml/api/dl/CaffeLayer.scala    | 57 ++-----------
 .../functions/misc/AssertExpressionTest.java    | 87 ++++++++++++++++++++
 .../functions/misc/AssertExpressionTest1.dml    | 22 +++++
 19 files changed, 161 insertions(+), 57 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/docs/dml-language-reference.md
----------------------------------------------------------------------
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index a0cc0f7..5fb9de5 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1542,6 +1542,7 @@ append() | Append a string to another string separated by 
"\n" <br/> Limitation:
 toString() | Formats a Matrix or Frame object into a string. <br/> "rows" & 
"cols" : number of rows and columns to print<br/> "decimal" : number of digits 
after the decimal<br/>"sparse" : set to TRUE to print Matrix object in sparse 
format, i.e. _RowIndex_ _ColIndex_ _Value_<br/>"sep" and "linesep" : 
inter-element separator and the line separator strings| Input : (&lt;matrix&gt; 
or &lt;frame&gt;,<br/> &nbsp;&nbsp;rows=100,<br/> &nbsp;&nbsp;cols=100,<br/> 
&nbsp;&nbsp;decimal=3,<br/> &nbsp;&nbsp;sparse=FALSE,<br/> &nbsp;&nbsp;sep=" 
",<br/> &nbsp;&nbsp;linesep="\n") <br/> Output: &lt;string&gt; | X = 
matrix(seq(1, 9), rows=3, cols=3)<br/>str = toString(X, sep=" \| ") <br/><br/>F 
= as.frame(X)<br/>print(toString(F, rows=2, cols=2))
 print() | Prints a scalar variable. The print() function allows printf-style 
formatting by optionally allowing multiple arguments, where the first argument 
is the string that specifies the formatting and the additional arguments are 
the arguments to format. | Input: &lt;scalar&gt;<br/>or<br/>&lt;string, 
args...&gt; | print("hello") <br/> print("hello" + "world") <br/> print("value 
of x is " + x ) 
<br/><br/>a='hello';<br/>b=3;<br/>c=4.5;<br/>d=TRUE;<br/>print('%s %d %f %b', 
a, b, c, d); 
<br/><br/>a='hello';<br/>b='goodbye';<br/>c=4;<br/>d=3;<br/>e=3.0;<br/>f=5.0;<br/>g=FALSE;<br/>print('%s
 %d %f %b', (a+b), (c-d), (e*f), !g);
 stop() | Halts the execution of DML program by printing the message that is 
passed in as the argument. <br/> Note that the use of stop() is not allowed 
inside a parfor loop. |  Input: (&lt;scalar&gt;) | stop("Inputs to DML program 
are invalid") <br/> stop("Class labels must be either -1 or +1")
+assert() | Halts the execution of the DML program if the boolean argument does not evaluate to TRUE. <br/> Note that the use of assert() is not allowed inside a parfor loop. |  Input: (&lt;scalar of type boolean&gt;) | assert(1 != 2)
 order() | Sort a column of the matrix X in decreasing/increasing order and 
return either index (index.return=TRUE) or data (index.return=FALSE). | Input: 
(target=X, by=column, decreasing, index.return) | order(X, by=1, 
decreasing=FALSE, index.return=FALSE)
 
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index a966425..905d25d 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1051,7 +1051,7 @@ public abstract class Hop implements ParseInfo
        public enum OpOp1 {
                NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SINH, COSH, TANH, SIGN, SQRT, LOG, EXP, 
                CAST_AS_SCALAR, CAST_AS_MATRIX, CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN,
-               PRINT, EIGEN, NROW, NCOL, LENGTH, ROUND, IQM, STOP, CEIL, FLOOR, MEDIAN, INVERSE, CHOLESKY,
+               PRINT, ASSERT, EIGEN, NROW, NCOL, LENGTH, ROUND, IQM, STOP, CEIL, FLOOR, MEDIAN, INVERSE, CHOLESKY,
                SVD,
                //cumulative sums, products, extreme values
                CUMSUM, CUMPROD, CUMMIN, CUMMAX,
@@ -1344,6 +1344,7 @@ public abstract class Hop implements ParseInfo
                HopsOpOp1LopsUS.put(OpOp1.NCOL, org.apache.sysml.lops.UnaryCP.OperationTypes.NCOL);
                HopsOpOp1LopsUS.put(OpOp1.LENGTH, org.apache.sysml.lops.UnaryCP.OperationTypes.LENGTH);
                HopsOpOp1LopsUS.put(OpOp1.PRINT, org.apache.sysml.lops.UnaryCP.OperationTypes.PRINT);
+               HopsOpOp1LopsUS.put(OpOp1.ASSERT, org.apache.sysml.lops.UnaryCP.OperationTypes.ASSERT);
                HopsOpOp1LopsUS.put(OpOp1.ROUND, org.apache.sysml.lops.UnaryCP.OperationTypes.ROUND);
                HopsOpOp1LopsUS.put(OpOp1.CEIL, org.apache.sysml.lops.UnaryCP.OperationTypes.CEIL);
                HopsOpOp1LopsUS.put(OpOp1.FLOOR, org.apache.sysml.lops.UnaryCP.OperationTypes.FLOOR);
@@ -1389,6 +1390,7 @@ public abstract class Hop implements ParseInfo
                HopsOpOp12String.put(OpOp1.NOT, "!");
                HopsOpOp12String.put(OpOp1.NROW, "nrow");
                HopsOpOp12String.put(OpOp1.PRINT, "print");
+               HopsOpOp12String.put(OpOp1.ASSERT, "assert");
                HopsOpOp12String.put(OpOp1.ROUND, "round");
                HopsOpOp12String.put(OpOp1.SIN, "sin");
                HopsOpOp12String.put(OpOp1.SQRT, "sqrt");

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/UnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java
index ff1b954..c844432 100644
--- a/src/main/java/org/apache/sysml/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java
@@ -689,7 +689,7 @@ public class UnaryOp extends Hop implements MultiThreadedHop
                setRequiresRecompileIfNecessary();
                
                //ensure cp exec type for single-node operations
-               if( _op == OpOp1.PRINT || _op == OpOp1.STOP 
+               if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op == OpOp1.STOP 
                        || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD)
                {
                        _etype = ExecType.CP;

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
index 9cddf57..0dd9480 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
@@ -646,6 +646,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                                case NCOL:
                                case NROW:
                                case PRINT:
+                               case ASSERT:
                                case CAST_AS_BOOLEAN:
                                case CAST_AS_DOUBLE:
                                case CAST_AS_INT:

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 68718d3..89bb1e4 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -938,6 +938,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
                                case NCOL:
                                case NROW:
                                case PRINT:
+                               case ASSERT:
                                case CAST_AS_BOOLEAN:
                                case CAST_AS_DOUBLE:
                                case CAST_AS_INT:

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
index 9fefeb9..b25a671 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
@@ -260,7 +260,8 @@ public class RewriteConstantFolding extends HopRewriteRule
                ArrayList<Hop> in = hop.getInput();
                return (   hop instanceof UnaryOp 
                                && in.get(0) instanceof LiteralOp 
-                               && ((UnaryOp)hop).getOp() != OpOp1.PRINT 
+                               && ((UnaryOp)hop).getOp() != OpOp1.PRINT
+                               && ((UnaryOp)hop).getOp() != OpOp1.ASSERT
                                && ((UnaryOp)hop).getOp() != OpOp1.STOP
                                && hop.getDataType() == DataType.SCALAR);
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveDanglingParentReferences.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveDanglingParentReferences.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveDanglingParentReferences.java
index 573b9fb..824c56b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveDanglingParentReferences.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveDanglingParentReferences.java
@@ -108,6 +108,7 @@ public class RewriteRemoveDanglingParentReferences extends HopRewriteRule
                return (hop instanceof DataOp && ((DataOp)hop).isWrite())
                        || (hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.STOP)
                        || (hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.PRINT)
+                       || (hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.ASSERT)
                        || (hop instanceof NaryOp && ((NaryOp)hop).getOp()==OpOpN.PRINTF)
                        || (hop instanceof FunctionOp)
                        || (hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.FUNCTIONOUTPUT);

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/lops/UnaryCP.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/UnaryCP.java b/src/main/java/org/apache/sysml/lops/UnaryCP.java
index 1af06e2..a444178 100644
--- a/src/main/java/org/apache/sysml/lops/UnaryCP.java
+++ b/src/main/java/org/apache/sysml/lops/UnaryCP.java
@@ -36,7 +36,7 @@ public class UnaryCP extends Lop
        public enum OperationTypes {
                NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SQRT, LOG, EXP, SINH, COSH, TANH,
                CAST_AS_SCALAR, CAST_AS_MATRIX, CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN, 
-               PRINT, NROW, NCOL, LENGTH, ROUND, STOP, CEIL, FLOOR, CUMSUM, SOFTMAX
+               PRINT, ASSERT, NROW, NCOL, LENGTH, ROUND, STOP, CEIL, FLOOR, CUMSUM, SOFTMAX
        }
        
        public static final String CAST_AS_SCALAR_OPCODE = "castdts";
@@ -134,6 +134,9 @@ public class UnaryCP extends Lop
 
                case PRINT:
                        return "print";
+
+               case ASSERT:
+                       return "assert";
 
                case CAST_AS_MATRIX:
                        return CAST_AS_MATRIX_OPCODE;

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 4422147..9a47d09 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -1215,6 +1215,13 @@ public class DMLTranslator
                                                Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
                                                printHop.setParseInfo(current);
                                                output.add(printHop);
+                                       } else if (ptype == PRINTTYPE.ASSERT) {
+                                               Hop.OpOp1 op = Hop.OpOp1.ASSERT;
+                                               Expression source = ps.getExpressions().get(0);
+                                               Hop ae = processExpression(source, target, ids);
+                                               Hop assertHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae);
+                                               assertHop.setParseInfo(current);
+                                               output.add(assertHop);
                                        } else if (ptype == PRINTTYPE.STOP) {
                                                Hop.OpOp1 op = Hop.OpOp1.STOP;
                                                Expression source = ps.getExpressions().get(0);

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java b/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
index 371b22a..3d829ef 100644
--- a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
+++ b/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
@@ -468,6 +468,10 @@ public class ParForStatementBlock extends ForStatementBlock
                                        raiseValidateError("PARFOR loop dependency analysis: " +
                                                           "stop() statement is not allowed inside a parfor loop body." , false);
                                }
+                               else if( s instanceof PrintStatement && ((PrintStatement)s).getType() == PRINTTYPE.ASSERT ) {
+                                       raiseValidateError("PARFOR loop dependency analysis: " +
+                                                          "assert() statement is not allowed inside a parfor loop body." , false);
+                               }
                                else
                                {
                                        VariableSet vsUpdated = s.variablesUpdated();

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/parser/PrintStatement.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/PrintStatement.java b/src/main/java/org/apache/sysml/parser/PrintStatement.java
index e964595..4740586 100644
--- a/src/main/java/org/apache/sysml/parser/PrintStatement.java
+++ b/src/main/java/org/apache/sysml/parser/PrintStatement.java
@@ -36,7 +36,7 @@ public class PrintStatement extends Statement
         * built-in function.
         */
        public enum PRINTTYPE {
-               PRINT, PRINTF, STOP
+               PRINT, PRINTF, STOP, ASSERT
        }
 
        protected PRINTTYPE _type; // print, printf, or stop
@@ -50,6 +50,9 @@ public class PrintStatement extends Statement
                                return PRINTTYPE.PRINTF;
                        }
                }
+               else if (type.equalsIgnoreCase("assert")) {
+                       return PRINTTYPE.ASSERT;
+               }
                else if (type.equalsIgnoreCase("stop")) {
                        return PRINTTYPE.STOP;
                }
@@ -105,7 +108,7 @@ public class PrintStatement extends Statement
        public String toString() {
                StringBuilder sb = new StringBuilder();
                sb.append(_type + "(");
-               if ((_type == PRINTTYPE.PRINT) || (_type == PRINTTYPE.STOP)) {
+               if ((_type == PRINTTYPE.PRINT) || (_type == PRINTTYPE.STOP) || (_type == PRINTTYPE.ASSERT)) {
                        Expression expression = expressions.get(0);
                        if (expression instanceof StringIdentifier) {
                                sb.append("\"");

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/parser/StatementBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java
index f7f6426..34a023a 100644
--- a/src/main/java/org/apache/sysml/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java
@@ -155,6 +155,9 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
                return _splitDag;
        }
 
+       private boolean isNonMergeablePrintStatement(Statement stmt) {
+               return ( stmt instanceof PrintStatement && (((PrintStatement)stmt).getType() == PRINTTYPE.STOP || ((PrintStatement)stmt).getType() == PRINTTYPE.ASSERT) );
+       }
 
     public boolean isMergeableFunctionCallBlock(DMLProgram dmlProg) throws LanguageException{
 
@@ -167,7 +170,7 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
 
                // Check whether targetIndex block is: control stmt block or stmt block for un-mergable function call
                if (   stmt instanceof WhileStatement || stmt instanceof IfStatement || stmt instanceof ForStatement
-                       || stmt instanceof FunctionStatement || ( stmt instanceof PrintStatement && ((PrintStatement)stmt).getType() == PRINTTYPE.STOP )/*|| stmt instanceof ELStatement*/ )
+                       || stmt instanceof FunctionStatement || isNonMergeablePrintStatement(stmt) /*|| stmt instanceof ELStatement*/ )
                {
                        return false;
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
index b7a1c89..56642d9 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java
@@ -477,6 +477,7 @@ public class DmlSyntacticValidator extends CommonSyntacticValidator implements D
                Set<String> printStatements = new  HashSet<>();
                printStatements.add("print");
                printStatements.add("stop");
+               printStatements.add("assert");
 
                Set<String> outputStatements = new HashSet<>();
                outputStatements.add("write");

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index 4928ac3..169d0b4 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -165,6 +165,7 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "sqrt"  , CPType.Unary);
                String2CPInstructionType.put( "plogp" , CPType.Unary);
                String2CPInstructionType.put( "print" , CPType.Unary);
+               String2CPInstructionType.put( "assert" , CPType.Unary);
                String2CPInstructionType.put( "round" , CPType.Unary);
                String2CPInstructionType.put( "ceil"  , CPType.Unary);
                String2CPInstructionType.put( "floor" , CPType.Unary);

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryScalarCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryScalarCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryScalarCPInstruction.java
index b582fdc..fd15b80 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryScalarCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryScalarCPInstruction.java
@@ -59,6 +59,13 @@ public class UnaryScalarCPInstruction extends UnaryMatrixCPInstruction {
                else if ( opcode.equalsIgnoreCase("stop") ) {
                        throw new DMLScriptException(so.getStringValue());
                }
+               else if ( opcode.equalsIgnoreCase("assert") ) {
+                       sores = new BooleanObject(so.getBooleanValue());
+                       if(!so.getBooleanValue()) {
+                               String fileName = this.getFilename() == null ? "" : this.getFilename() + " ";
+                               throw new DMLScriptException("assertion failed at " + fileName  + this.getBeginLine() + ":" + this.getBeginColumn() + "-" + this.getEndLine() + ":" + this.getEndColumn());
+                       }
+               }
                else {
                        UnaryOperator dop = (UnaryOperator) _optr;
                        if ( so instanceof IntObject && output.getValueType() == ValueType.INT )

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index 0a215b1..5d17a4d 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -305,6 +305,7 @@ class Caffe2DML(val sc: SparkContext,
 
   // TODO: throw error or warning if user tries to set solver_mode == GPU instead of using setGPU method
 
+  def containsParfor():Boolean = getTrainAlgo.toLowerCase.startsWith("allreduce") || getTestAlgo.toLowerCase.startsWith("allreduce")
   def getTrainAlgo(): String = if (inputs.containsKey("$train_algo")) inputs.get("$train_algo") else "minibatch"
   def getTestAlgo(): String  = if (inputs.containsKey("$test_algo")) inputs.get("$test_algo") else "minibatch"
 
@@ -360,6 +361,7 @@ class Caffe2DML(val sc: SparkContext,
   
   def setDebugFlags(isDebug:Boolean):Unit = {
     net.getLayers.map(layer => {net.getCaffeLayer(layer).debugLayer = isDebug})
+    net.getLayers.map(layer => {net.getCaffeLayer(layer).caffe2dmlObj = this})
   }
 
   // ================================================================================================

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index 65a9921..37b585f 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -43,21 +43,7 @@ trait CaffeLayer extends BaseDMLGenerator {
   }
   // -------------------------------------------------
   var debugLayer = false
-  def validateDimensions(dmlScript: StringBuilder, mat:String, expectedNumRows:String, expectedNumCols:String, optionalString:String=""):Unit = {
-    if(debugLayer) {
-      val msg = " in " + sourceFileName + "(" + optionalString + ") script."
-      if(expectedNumRows != null) {
-        dmlScript.append("\nif( " + expectedNumRows + " != nrow(" + mat + ")) {\n")
-        dmlScript.append("\tstop(\"Incorrect number of rows for " + mat + msg + " Expected:\" + " + expectedNumRows + " + \" but found \" +  nrow(" + mat + ") )") 
-        dmlScript.append("\n}\n")
-      }
-      if(expectedNumCols != null) {
-        dmlScript.append("\nif( " + expectedNumCols + " != ncol(" + mat + ")) {\n")
-        dmlScript.append("\tstop(\"Incorrect number of columns for " + mat + msg + " Expected:\" + " + expectedNumCols + " + \" but found \" +  ncol(" + mat + ") )") 
-        dmlScript.append("\n}\n")
-      }
-    }
-  }
+  var caffe2dmlObj:Caffe2DML = null
   var computedBottomLayerOutputShape: (String, String, String) = null
   def bottomLayerOutputShape: (String, String, String) = {
     if (computedBottomLayerOutputShape == null) {
@@ -868,13 +854,12 @@ class InnerProduct(val param: LayerParameter, val id: Int, val net: CaffeNetwork
    *  - out: Outputs, of shape (N, M).
    */
   override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = {
-    val D = numFeatures
-    val M = numNeurons
-    validateDimensions(dmlScript, X, null, D)
-    validateDimensions(dmlScript, weight, D, M, "forward")
-    validateDimensions(dmlScript, bias, "1", M)
+    if(debugLayer && caffe2dmlObj != null && !caffe2dmlObj.containsParfor) {
+      dmlScript.append("assert(ncol(" + X + ") == nrow(" + weight + ") & ncol(" + weight + ") == ncol(" + bias + ")); ")
+    }
     invokeForward(dmlScript, List[String](out), X, weight, bias)
   }
+
   /*
    * Computes the backward pass for a fully-connected (affine) layer
    * with M neurons.
@@ -890,15 +875,9 @@ class InnerProduct(val param: LayerParameter, val id: Int, val net: CaffeNetwork
    *  - dW: Gradient wrt `W`, of shape (D, M).
    *  - db: Gradient wrt `b`, of shape (1, M).
    */
-  override def backward(dmlScript: StringBuilder, outSuffix: String) = {
-    val D = numFeatures
-    val M = numNeurons
-    validateDimensions(dmlScript, dout, null, M)
-    validateDimensions(dmlScript, X, null, D)
-    validateDimensions(dmlScript, weight, D, M, "backward")
-    validateDimensions(dmlScript, bias, "1", M)
+  override def backward(dmlScript: StringBuilder, outSuffix: String) =
     invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), dout, X, weight, bias)
-  }
+
   // -------------------------------------------------
   // num_output (c_o): the number of filters
   def numNeurons  = param.getInnerProductParam.getNumOutput.toString
@@ -970,34 +949,12 @@ class LSTM(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extend
     val N:String = null // output_features.toString
     val T = timesteps()
     val D = input_features()
-    validateDimensions(dmlScript, X, N, T + "*" + D)
-    validateDimensions(dmlScript, out0, N, M)
-    validateDimensions(dmlScript, c0, N, M)
-    validateDimensions(dmlScript, weight, D + "+" + M, 4 + "*" + M)
-    validateDimensions(dmlScript, bias, "1", 4 + "*" + M)
     invokeForward(dmlScript, List[String](out, c, cache_out, cache_c, cache_ifog), X, weight, bias, T, D, return_sequences.toString.toUpperCase, out0, c0)
-    // This validates whether the output is of correct dimensions
-    validateDimensions(dmlScript, out, null, int_mult(outputShape._1, outputShape._2, outputShape._3))
   }
   
   override def backward(dmlScript: StringBuilder, outSuffix: String) = {
     val T = timesteps()
     val D = input_features()
-    if(return_sequences) {
-      validateDimensions(dmlScript, dout, null, T + "*" + M)
-    }
-    else {
-      validateDimensions(dmlScript, dout, null, M)
-    }
-    validateDimensions(dmlScript, dc0, null, M)
-    validateDimensions(dmlScript, X, null, T + "*" + D)
-    validateDimensions(dmlScript, out0, null, M)
-    validateDimensions(dmlScript, c0, null, M)
-    validateDimensions(dmlScript, cache_out, T, null)
-    validateDimensions(dmlScript, cache_c, T, null)
-    validateDimensions(dmlScript, cache_ifog, T, null)
-    validateDimensions(dmlScript, weight, D + "+" + M, 4 + "*" + M)
-    validateDimensions(dmlScript, bias, "1", 4 + "*" + M)
     invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias, dout0, dc0), dout, dc0, X, weight, bias,
         T, D, return_sequences.toString.toUpperCase, out0, c0, cache_out, cache_c, cache_ifog)
   }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/test/java/org/apache/sysml/test/integration/functions/misc/AssertExpressionTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/AssertExpressionTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/AssertExpressionTest.java
new file mode 100644
index 0000000..adda447
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/AssertExpressionTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import org.junit.Test;
+
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+/**
+ * Integration test for the {@code assert} builtin.
+ * 
+ * Runs a DML script whose assertions all hold and expects it to complete
+ * without an exception. Each scenario runs once with rewrites disabled and
+ * once with them enabled.
+ */
+public class AssertExpressionTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME1 = "AssertExpressionTest1";
+       private final static String TEST_DIR = "functions/misc/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + AssertExpressionTest.class.getSimpleName() + "/";
+       
+       @Override
+       public void setUp() 
+       {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }));
+       }
+       
+       @Test
+       public void testAssertExpressionTrue() {
+               runAssertExpressionTest(TEST_NAME1, false);
+       }
+       
+       @Test
+       public void testAssertExpressionTrueRewrite() {
+               runAssertExpressionTest(TEST_NAME1, true);
+       }
+       
+       /**
+        * Runs the given DML test script and expects it to finish without an exception.
+        * 
+        * @param testname name of the registered test configuration, which is also
+        *                 the DML script's file name without the ".dml" extension
+        * @param rewrites whether algebraic simplification and constant folding are
+        *                 enabled. Constant folding matters here because it can turn a
+        *                 literal-only assert argument into a constant before run time.
+        */
+       private void runAssertExpressionTest( String testname, boolean rewrites )
+       {
+               TestConfiguration config = getTestConfiguration(testname);
+               loadTestConfiguration(config);
+               
+               // These optimizer flags are global static state; save them so they
+               // can be restored for subsequent tests.
+               boolean oldRewriteFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               boolean oldFoldingFlag = OptimizerUtils.ALLOW_CONSTANT_FOLDING;
+               OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+               OptimizerUtils.ALLOW_CONSTANT_FOLDING = rewrites;
+               
+               try
+               {
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{"-args", output("R")};
+                       
+                       // There is no R baseline to compare against: the test passes
+                       // if the DML run raises no exception (exceptionExpected = false).
+                       runTest(true, false, null, -1);
+               }
+               finally
+               {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewriteFlag;
+                       OptimizerUtils.ALLOW_CONSTANT_FOLDING = oldFoldingFlag;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/test/scripts/functions/misc/AssertExpressionTest1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/AssertExpressionTest1.dml 
b/src/test/scripts/functions/misc/AssertExpressionTest1.dml
new file mode 100644
index 0000000..d7c30bf
--- /dev/null
+++ b/src/test/scripts/functions/misc/AssertExpressionTest1.dml
@@ -0,0 +1,22 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# assert must not halt execution when its argument evaluates to TRUE.
+# Case 1: a condition built only from literals.
+assert(1 != 2);
+# Case 2: a condition computed from matrix data rather than literals alone
+# (seq(1, 4) sums to 10).
+X = seq(1, 4);
+assert(sum(X) == 10);

Reply via email to