Repository: systemml Updated Branches: refs/heads/master f702c03be -> 0529350a3
[SYSTEMML-540] Added an assert builtin function - The assert function halts the execution of the DML program if its boolean argument does not evaluate to TRUE. - Like stop, assert is not supported inside a parfor. - Caffe2DML inserts an assert call for the affine layer when no parfor is used. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0529350a Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0529350a Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0529350a Branch: refs/heads/master Commit: 0529350a3dfa665bed040da255388f513b5868d2 Parents: f702c03 Author: Niketan Pansare <[email protected]> Authored: Fri Jan 19 14:46:26 2018 -0800 Committer: Niketan Pansare <[email protected]> Committed: Fri Jan 19 14:46:26 2018 -0800 ---------------------------------------------------------------------- docs/dml-language-reference.md | 1 + src/main/java/org/apache/sysml/hops/Hop.java | 4 +- .../java/org/apache/sysml/hops/UnaryOp.java | 2 +- .../codegen/opt/PlanSelectionFuseCostBased.java | 1 + .../opt/PlanSelectionFuseCostBasedV2.java | 1 + .../hops/rewrite/RewriteConstantFolding.java | 3 +- .../RewriteRemoveDanglingParentReferences.java | 1 + .../java/org/apache/sysml/lops/UnaryCP.java | 5 +- .../org/apache/sysml/parser/DMLTranslator.java | 7 ++ .../sysml/parser/ParForStatementBlock.java | 4 + .../org/apache/sysml/parser/PrintStatement.java | 7 +- .../org/apache/sysml/parser/StatementBlock.java | 5 +- .../sysml/parser/dml/DmlSyntacticValidator.java | 1 + .../instructions/CPInstructionParser.java | 1 + .../cp/UnaryScalarCPInstruction.java | 7 ++ .../org/apache/sysml/api/dl/Caffe2DML.scala | 2 + .../org/apache/sysml/api/dl/CaffeLayer.scala | 57 ++----------- .../functions/misc/AssertExpressionTest.java | 87 ++++++++++++++++++++ .../functions/misc/AssertExpressionTest1.dml | 22 +++++ 19 files changed, 161 insertions(+), 57 deletions(-) 
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/docs/dml-language-reference.md ---------------------------------------------------------------------- diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md index a0cc0f7..5fb9de5 100644 --- a/docs/dml-language-reference.md +++ b/docs/dml-language-reference.md @@ -1542,6 +1542,7 @@ append() | Append a string to another string separated by "\n" <br/> Limitation: toString() | Formats a Matrix or Frame object into a string. <br/> "rows" & "cols" : number of rows and columns to print<br/> "decimal" : number of digits after the decimal<br/>"sparse" : set to TRUE to print Matrix object in sparse format, i.e. _RowIndex_ _ColIndex_ _Value_<br/>"sep" and "linesep" : inter-element separator and the line separator strings| Input : (<matrix> or <frame>,<br/> rows=100,<br/> cols=100,<br/> decimal=3,<br/> sparse=FALSE,<br/> sep=" ",<br/> linesep="\n") <br/> Output: <string> | X = matrix(seq(1, 9), rows=3, cols=3)<br/>str = toString(X, sep=" \| ") <br/><br/>F = as.frame(X)<br/>print(toString(F, rows=2, cols=2)) print() | Prints a scalar variable. The print() function allows printf-style formatting by optionally allowing multiple arguments, where the first argument is the string that specifies the formatting and the additional arguments are the arguments to format. | Input: <scalar><br/>or<br/><string, args...> | print("hello") <br/> print("hello" + "world") <br/> print("value of x is " + x ) <br/><br/>a='hello';<br/>b=3;<br/>c=4.5;<br/>d=TRUE;<br/>print('%s %d %f %b', a, b, c, d); <br/><br/>a='hello';<br/>b='goodbye';<br/>c=4;<br/>d=3;<br/>e=3.0;<br/>f=5.0;<br/>g=FALSE;<br/>print('%s %d %f %b', (a+b), (c-d), (e*f), !g); stop() | Halts the execution of DML program by printing the message that is passed in as the argument. <br/> Note that the use of stop() is not allowed inside a parfor loop. 
| Input: (<scalar>) | stop("Inputs to DML program are invalid") <br/> stop("Class labels must be either -1 or +1") +assert() | Halts the execution of DML program if the boolean argument doesnot evaluate to TRUE. <br/> Note that the use of assert() is not allowed inside a parfor loop. | Input: (<scalar of type boolean>) | assert(1 != 2) order() | Sort a column of the matrix X in decreasing/increasing order and return either index (index.return=TRUE) or data (index.return=FALSE). | Input: (target=X, by=column, decreasing, index.return) | order(X, by=1, decreasing=FALSE, index.return=FALSE) http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index a966425..905d25d 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1051,7 +1051,7 @@ public abstract class Hop implements ParseInfo public enum OpOp1 { NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SINH, COSH, TANH, SIGN, SQRT, LOG, EXP, CAST_AS_SCALAR, CAST_AS_MATRIX, CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN, - PRINT, EIGEN, NROW, NCOL, LENGTH, ROUND, IQM, STOP, CEIL, FLOOR, MEDIAN, INVERSE, CHOLESKY, + PRINT, ASSERT, EIGEN, NROW, NCOL, LENGTH, ROUND, IQM, STOP, CEIL, FLOOR, MEDIAN, INVERSE, CHOLESKY, SVD, //cumulative sums, products, extreme values CUMSUM, CUMPROD, CUMMIN, CUMMAX, @@ -1344,6 +1344,7 @@ public abstract class Hop implements ParseInfo HopsOpOp1LopsUS.put(OpOp1.NCOL, org.apache.sysml.lops.UnaryCP.OperationTypes.NCOL); HopsOpOp1LopsUS.put(OpOp1.LENGTH, org.apache.sysml.lops.UnaryCP.OperationTypes.LENGTH); HopsOpOp1LopsUS.put(OpOp1.PRINT, org.apache.sysml.lops.UnaryCP.OperationTypes.PRINT); + HopsOpOp1LopsUS.put(OpOp1.ASSERT, org.apache.sysml.lops.UnaryCP.OperationTypes.ASSERT); 
HopsOpOp1LopsUS.put(OpOp1.ROUND, org.apache.sysml.lops.UnaryCP.OperationTypes.ROUND); HopsOpOp1LopsUS.put(OpOp1.CEIL, org.apache.sysml.lops.UnaryCP.OperationTypes.CEIL); HopsOpOp1LopsUS.put(OpOp1.FLOOR, org.apache.sysml.lops.UnaryCP.OperationTypes.FLOOR); @@ -1389,6 +1390,7 @@ public abstract class Hop implements ParseInfo HopsOpOp12String.put(OpOp1.NOT, "!"); HopsOpOp12String.put(OpOp1.NROW, "nrow"); HopsOpOp12String.put(OpOp1.PRINT, "print"); + HopsOpOp12String.put(OpOp1.ASSERT, "assert"); HopsOpOp12String.put(OpOp1.ROUND, "round"); HopsOpOp12String.put(OpOp1.SIN, "sin"); HopsOpOp12String.put(OpOp1.SQRT, "sqrt"); http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/UnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java index ff1b954..c844432 100644 --- a/src/main/java/org/apache/sysml/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java @@ -689,7 +689,7 @@ public class UnaryOp extends Hop implements MultiThreadedHop setRequiresRecompileIfNecessary(); //ensure cp exec type for single-node operations - if( _op == OpOp1.PRINT || _op == OpOp1.STOP + if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op == OpOp1.STOP || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD) { _etype = ExecType.CP; http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java index 9cddf57..0dd9480 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java +++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java @@ -646,6 +646,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection case NCOL: case NROW: case PRINT: + case ASSERT: case CAST_AS_BOOLEAN: case CAST_AS_DOUBLE: case CAST_AS_INT: http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 68718d3..89bb1e4 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -938,6 +938,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection case NCOL: case NROW: case PRINT: + case ASSERT: case CAST_AS_BOOLEAN: case CAST_AS_DOUBLE: case CAST_AS_INT: http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java index 9fefeb9..b25a671 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java @@ -260,7 +260,8 @@ public class RewriteConstantFolding extends HopRewriteRule ArrayList<Hop> in = hop.getInput(); return ( hop instanceof UnaryOp && in.get(0) instanceof LiteralOp - && ((UnaryOp)hop).getOp() != OpOp1.PRINT + && ((UnaryOp)hop).getOp() != OpOp1.PRINT + && ((UnaryOp)hop).getOp() != OpOp1.ASSERT && ((UnaryOp)hop).getOp() != OpOp1.STOP && hop.getDataType() == 
DataType.SCALAR); } http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveDanglingParentReferences.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveDanglingParentReferences.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveDanglingParentReferences.java index 573b9fb..824c56b 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveDanglingParentReferences.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveDanglingParentReferences.java @@ -108,6 +108,7 @@ public class RewriteRemoveDanglingParentReferences extends HopRewriteRule return (hop instanceof DataOp && ((DataOp)hop).isWrite()) || (hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.STOP) || (hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.PRINT) + || (hop instanceof UnaryOp && ((UnaryOp)hop).getOp()==OpOp1.ASSERT) || (hop instanceof NaryOp && ((NaryOp)hop).getOp()==OpOpN.PRINTF) || (hop instanceof FunctionOp) || (hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.FUNCTIONOUTPUT); http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/lops/UnaryCP.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/UnaryCP.java b/src/main/java/org/apache/sysml/lops/UnaryCP.java index 1af06e2..a444178 100644 --- a/src/main/java/org/apache/sysml/lops/UnaryCP.java +++ b/src/main/java/org/apache/sysml/lops/UnaryCP.java @@ -36,7 +36,7 @@ public class UnaryCP extends Lop public enum OperationTypes { NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SQRT, LOG, EXP, SINH, COSH, TANH, CAST_AS_SCALAR, CAST_AS_MATRIX, CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN, - PRINT, NROW, NCOL, LENGTH, ROUND, STOP, CEIL, FLOOR, CUMSUM, SOFTMAX + PRINT, ASSERT, NROW, NCOL, LENGTH, ROUND, STOP, CEIL, 
FLOOR, CUMSUM, SOFTMAX } public static final String CAST_AS_SCALAR_OPCODE = "castdts"; @@ -134,6 +134,9 @@ public class UnaryCP extends Lop case PRINT: return "print"; + + case ASSERT: + return "assert"; case CAST_AS_MATRIX: return CAST_AS_MATRIX_OPCODE; http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 4422147..9a47d09 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -1215,6 +1215,13 @@ public class DMLTranslator Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae); printHop.setParseInfo(current); output.add(printHop); + } else if (ptype == PRINTTYPE.ASSERT) { + Hop.OpOp1 op = Hop.OpOp1.ASSERT; + Expression source = ps.getExpressions().get(0); + Hop ae = processExpression(source, target, ids); + Hop printHop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), op, ae); + printHop.setParseInfo(current); + output.add(printHop); } else if (ptype == PRINTTYPE.STOP) { Hop.OpOp1 op = Hop.OpOp1.STOP; Expression source = ps.getExpressions().get(0); http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java b/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java index 371b22a..3d829ef 100644 --- a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java @@ -468,6 +468,10 @@ public class ParForStatementBlock extends ForStatementBlock 
raiseValidateError("PARFOR loop dependency analysis: " + "stop() statement is not allowed inside a parfor loop body." , false); } + else if( s instanceof PrintStatement && ((PrintStatement)s).getType() == PRINTTYPE.ASSERT ) { + raiseValidateError("PARFOR loop dependency analysis: " + + "assert() statement is not allowed inside a parfor loop body." , false); + } else { VariableSet vsUpdated = s.variablesUpdated(); http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/parser/PrintStatement.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/PrintStatement.java b/src/main/java/org/apache/sysml/parser/PrintStatement.java index e964595..4740586 100644 --- a/src/main/java/org/apache/sysml/parser/PrintStatement.java +++ b/src/main/java/org/apache/sysml/parser/PrintStatement.java @@ -36,7 +36,7 @@ public class PrintStatement extends Statement * built-in function. */ public enum PRINTTYPE { - PRINT, PRINTF, STOP + PRINT, PRINTF, STOP, ASSERT } protected PRINTTYPE _type; // print, printf, or stop @@ -50,6 +50,9 @@ public class PrintStatement extends Statement return PRINTTYPE.PRINTF; } } + else if (type.equalsIgnoreCase("assert")) { + return PRINTTYPE.ASSERT; + } else if (type.equalsIgnoreCase("stop")) { return PRINTTYPE.STOP; } @@ -105,7 +108,7 @@ public class PrintStatement extends Statement public String toString() { StringBuilder sb = new StringBuilder(); sb.append(_type + "("); - if ((_type == PRINTTYPE.PRINT) || (_type == PRINTTYPE.STOP)) { + if ((_type == PRINTTYPE.PRINT) || (_type == PRINTTYPE.STOP) || (_type == PRINTTYPE.ASSERT)) { Expression expression = expressions.get(0); if (expression instanceof StringIdentifier) { sb.append("\""); http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/parser/StatementBlock.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java index f7f6426..34a023a 100644 --- a/src/main/java/org/apache/sysml/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java @@ -155,6 +155,9 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo return _splitDag; } + private boolean isMergeablePrintStatement(Statement stmt) { + return ( stmt instanceof PrintStatement && (((PrintStatement)stmt).getType() == PRINTTYPE.STOP || ((PrintStatement)stmt).getType() == PRINTTYPE.ASSERT) ); + } public boolean isMergeableFunctionCallBlock(DMLProgram dmlProg) throws LanguageException{ @@ -167,7 +170,7 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo // Check whether targetIndex block is: control stmt block or stmt block for un-mergable function call if ( stmt instanceof WhileStatement || stmt instanceof IfStatement || stmt instanceof ForStatement - || stmt instanceof FunctionStatement || ( stmt instanceof PrintStatement && ((PrintStatement)stmt).getType() == PRINTTYPE.STOP )/*|| stmt instanceof ELStatement*/ ) + || stmt instanceof FunctionStatement || isMergeablePrintStatement(stmt) /*|| stmt instanceof ELStatement*/ ) { return false; } http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java index b7a1c89..56642d9 100644 --- a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java +++ b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java @@ -477,6 +477,7 @@ public class DmlSyntacticValidator extends CommonSyntacticValidator implements D Set<String> printStatements = new HashSet<>(); 
printStatements.add("print"); printStatements.add("stop"); + printStatements.add("assert"); Set<String> outputStatements = new HashSet<>(); outputStatements.add("write"); http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index 4928ac3..169d0b4 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -165,6 +165,7 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "sqrt" , CPType.Unary); String2CPInstructionType.put( "plogp" , CPType.Unary); String2CPInstructionType.put( "print" , CPType.Unary); + String2CPInstructionType.put( "assert" , CPType.Unary); String2CPInstructionType.put( "round" , CPType.Unary); String2CPInstructionType.put( "ceil" , CPType.Unary); String2CPInstructionType.put( "floor" , CPType.Unary); http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryScalarCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryScalarCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryScalarCPInstruction.java index b582fdc..fd15b80 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryScalarCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/UnaryScalarCPInstruction.java @@ -59,6 +59,13 @@ public class UnaryScalarCPInstruction extends UnaryMatrixCPInstruction { else if ( opcode.equalsIgnoreCase("stop") ) { throw new 
DMLScriptException(so.getStringValue()); } + else if ( opcode.equalsIgnoreCase("assert") ) { + sores = new BooleanObject(so.getBooleanValue()); + if(!so.getBooleanValue()) { + String fileName = this.getFilename() == null ? "" : this.getFilename() + " "; + throw new DMLScriptException("assertion failed at " + fileName + this.getBeginLine() + ":" + this.getBeginColumn() + "-" + this.getEndLine() + ":" + this.getEndColumn()); + } + } else { UnaryOperator dop = (UnaryOperator) _optr; if ( so instanceof IntObject && output.getValueType() == ValueType.INT ) http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala ---------------------------------------------------------------------- diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala index 0a215b1..5d17a4d 100644 --- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala +++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala @@ -305,6 +305,7 @@ class Caffe2DML(val sc: SparkContext, // TODO: throw error or warning if user tries to set solver_mode == GPU instead of using setGPU method + def containsParfor():Boolean = getTrainAlgo.toLowerCase.startsWith("allreduce") || getTestAlgo.toLowerCase.startsWith("allreduce") def getTrainAlgo(): String = if (inputs.containsKey("$train_algo")) inputs.get("$train_algo") else "minibatch" def getTestAlgo(): String = if (inputs.containsKey("$test_algo")) inputs.get("$test_algo") else "minibatch" @@ -360,6 +361,7 @@ class Caffe2DML(val sc: SparkContext, def setDebugFlags(isDebug:Boolean):Unit = { net.getLayers.map(layer => {net.getCaffeLayer(layer).debugLayer = isDebug}) + net.getLayers.map(layer => {net.getCaffeLayer(layer).caffe2dmlObj = this}) } // ================================================================================================ 
http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala ---------------------------------------------------------------------- diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala index 65a9921..37b585f 100644 --- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala +++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala @@ -43,21 +43,7 @@ trait CaffeLayer extends BaseDMLGenerator { } // ------------------------------------------------- var debugLayer = false - def validateDimensions(dmlScript: StringBuilder, mat:String, expectedNumRows:String, expectedNumCols:String, optionalString:String=""):Unit = { - if(debugLayer) { - val msg = " in " + sourceFileName + "(" + optionalString + ") script." - if(expectedNumRows != null) { - dmlScript.append("\nif( " + expectedNumRows + " != nrow(" + mat + ")) {\n") - dmlScript.append("\tstop(\"Incorrect number of rows for " + mat + msg + " Expected:\" + " + expectedNumRows + " + \" but found \" + nrow(" + mat + ") )") - dmlScript.append("\n}\n") - } - if(expectedNumCols != null) { - dmlScript.append("\nif( " + expectedNumCols + " != ncol(" + mat + ")) {\n") - dmlScript.append("\tstop(\"Incorrect number of columns for " + mat + msg + " Expected:\" + " + expectedNumCols + " + \" but found \" + ncol(" + mat + ") )") - dmlScript.append("\n}\n") - } - } - } + var caffe2dmlObj:Caffe2DML = null var computedBottomLayerOutputShape: (String, String, String) = null def bottomLayerOutputShape: (String, String, String) = { if (computedBottomLayerOutputShape == null) { @@ -868,13 +854,12 @@ class InnerProduct(val param: LayerParameter, val id: Int, val net: CaffeNetwork * - out: Outputs, of shape (N, M). 
*/ override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = { - val D = numFeatures - val M = numNeurons - validateDimensions(dmlScript, X, null, D) - validateDimensions(dmlScript, weight, D, M, "forward") - validateDimensions(dmlScript, bias, "1", M) + if(debugLayer && caffe2dmlObj != null && !caffe2dmlObj.containsParfor) { + dmlScript.append("assert(ncol(" + X + ") == nrow(" + weight + ") | ncol(" + weight + ") == ncol(" + bias + ")); ") + } invokeForward(dmlScript, List[String](out), X, weight, bias) } + /* * Computes the backward pass for a fully-connected (affine) layer * with M neurons. @@ -890,15 +875,9 @@ class InnerProduct(val param: LayerParameter, val id: Int, val net: CaffeNetwork * - dW: Gradient wrt `W`, of shape (D, M). * - db: Gradient wrt `b`, of shape (1, M). */ - override def backward(dmlScript: StringBuilder, outSuffix: String) = { - val D = numFeatures - val M = numNeurons - validateDimensions(dmlScript, dout, null, M) - validateDimensions(dmlScript, X, null, D) - validateDimensions(dmlScript, weight, D, M, "backward") - validateDimensions(dmlScript, bias, "1", M) + override def backward(dmlScript: StringBuilder, outSuffix: String) = invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), dout, X, weight, bias) - } + // ------------------------------------------------- // num_output (c_o): the number of filters def numNeurons = param.getInnerProductParam.getNumOutput.toString @@ -970,34 +949,12 @@ class LSTM(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extend val N:String = null // output_features.toString val T = timesteps() val D = input_features() - validateDimensions(dmlScript, X, N, T + "*" + D) - validateDimensions(dmlScript, out0, N, M) - validateDimensions(dmlScript, c0, N, M) - validateDimensions(dmlScript, weight, D + "+" + M, 4 + "*" + M) - validateDimensions(dmlScript, bias, "1", 4 + "*" + M) invokeForward(dmlScript, List[String](out, c, cache_out, cache_c, cache_ifog), X, 
weight, bias, T, D, return_sequences.toString.toUpperCase, out0, c0) - // This validates whether the output is of correct dimensions - validateDimensions(dmlScript, out, null, int_mult(outputShape._1, outputShape._2, outputShape._3)) } override def backward(dmlScript: StringBuilder, outSuffix: String) = { val T = timesteps() val D = input_features() - if(return_sequences) { - validateDimensions(dmlScript, dout, null, T + "*" + M) - } - else { - validateDimensions(dmlScript, dout, null, M) - } - validateDimensions(dmlScript, dc0, null, M) - validateDimensions(dmlScript, X, null, T + "*" + D) - validateDimensions(dmlScript, out0, null, M) - validateDimensions(dmlScript, c0, null, M) - validateDimensions(dmlScript, cache_out, T, null) - validateDimensions(dmlScript, cache_c, T, null) - validateDimensions(dmlScript, cache_ifog, T, null) - validateDimensions(dmlScript, weight, D + "+" + M, 4 + "*" + M) - validateDimensions(dmlScript, bias, "1", 4 + "*" + M) invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias, dout0, dc0), dout, dc0, X, weight, bias, T, D, return_sequences.toString.toUpperCase, out0, c0, cache_out, cache_c, cache_ifog) } http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/test/java/org/apache/sysml/test/integration/functions/misc/AssertExpressionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/AssertExpressionTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/AssertExpressionTest.java new file mode 100644 index 0000000..adda447 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/AssertExpressionTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.misc; + +import org.junit.Test; + +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +/** + * + */ +public class AssertExpressionTest extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "AssertExpressionTest1"; + private final static String TEST_DIR = "functions/misc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + AssertExpressionTest.class.getSimpleName() + "/"; + + @Override + public void setUp() + { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" })); + } + + @Test + public void testPrintNotExpressionTest() { + runPrintExpressionTest(TEST_NAME1, false); + } + + @Test + public void testPrintNotExpressionTestRewrite() { + runPrintExpressionTest(TEST_NAME1, true); + } + + /** + * + * @param testname + * @param rewrites + */ + private void runPrintExpressionTest( String testname, boolean rewrites ) + { + String TEST_NAME = testname; + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + //set rewrite configuration + boolean oldRewriteFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + try + { + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{"-args", output("R")}; + + fullRScriptName = HOME + TEST_NAME +".R"; + rCmd = getRCmd(expectedDir()); + + //run Tests + runTest(true, false, null, -1); + } + finally + { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewriteFlag; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/0529350a/src/test/scripts/functions/misc/AssertExpressionTest1.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/AssertExpressionTest1.dml b/src/test/scripts/functions/misc/AssertExpressionTest1.dml new file mode 100644 index 0000000..d7c30bf --- /dev/null +++ b/src/test/scripts/functions/misc/AssertExpressionTest1.dml @@ -0,0 +1,22 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +assert(1 != 2);
