Repository: systemml Updated Branches: refs/heads/master 276065f93 -> 9861d7a3c
[SYSTEMML-2340] Fix invalid parser constant propagation (conditional) This patch fixes special cases of invalid constant propagation during parsing, which prevents IPA (which is the primary component for size propagation) from correcting invalid sizes (e.g., in conditional for/while). We now disable parser constant propagation in conditional control flow, which simply defers it to IPA and rewrites. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9861d7a3 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9861d7a3 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9861d7a3 Branch: refs/heads/master Commit: 9861d7a3ce8eacb6b20c714ad6615eb1bc377a39 Parents: 276065f Author: Matthias Boehm <[email protected]> Authored: Fri Jun 1 20:14:04 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Jun 1 20:40:25 2018 -0700 ---------------------------------------------------------------------- .../apache/sysml/parser/BinaryExpression.java | 11 +++-- .../sysml/parser/BuiltinFunctionExpression.java | 24 +++++----- .../org/apache/sysml/parser/DMLTranslator.java | 3 +- .../org/apache/sysml/parser/DataExpression.java | 8 ++-- .../apache/sysml/parser/ForStatementBlock.java | 32 ++++++------- .../sysml/parser/FunctionCallIdentifier.java | 16 ++++--- .../apache/sysml/parser/IfStatementBlock.java | 2 +- .../apache/sysml/parser/IndexedIdentifier.java | 8 ++-- .../sysml/parser/RelationalExpression.java | 10 ++-- .../functions/misc/SizePropagationTest.java | 12 +++++ .../recompile/RandSizeExpressionEvalTest.java | 49 ++++++++------------ .../functions/misc/SizePropagationLoopIx4.dml | 32 +++++++++++++ 12 files changed, 122 insertions(+), 85 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/main/java/org/apache/sysml/parser/BinaryExpression.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BinaryExpression.java b/src/main/java/org/apache/sysml/parser/BinaryExpression.java index 83dc3d0..0ee7441 100644 --- a/src/main/java/org/apache/sysml/parser/BinaryExpression.java +++ b/src/main/java/org/apache/sysml/parser/BinaryExpression.java @@ -101,11 +101,12 @@ public class BinaryExpression extends Expression _right.validateExpression(ids, constVars, conditional); //constant propagation (precondition for more complex constant folding rewrite) - if( _left instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _left).getName()) ) - _left = constVars.get(((DataIdentifier) _left).getName()); - if( _right instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _right).getName()) ) - _right = constVars.get(((DataIdentifier) _right).getName()); - + if( !conditional ) { + if( _left instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _left).getName()) ) + _left = constVars.get(((DataIdentifier) _left).getName()); + if( _right instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _right).getName()) ) + _right = constVars.get(((DataIdentifier) _right).getName()); + } String outputName = getTempName(); DataIdentifier output = new DataIdentifier(outputName); http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index 8a0a6d8..42b8329 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -462,11 +462,9 @@ public class BuiltinFunctionExpression extends DataIdentifier public void 
validateExpression(HashMap<String, DataIdentifier> ids, HashMap<String, ConstIdentifier> constVars, boolean conditional) { for(int i=0; i < _args.length; i++ ) { - if (_args[i] instanceof FunctionCallIdentifier){ raiseValidateError("UDF function call not supported as parameter to built-in function call", false); } - _args[i].validateExpression(ids, constVars, conditional); } @@ -876,9 +874,9 @@ public class BuiltinFunctionExpression extends DataIdentifier } else { // constant propagation - if( getThirdExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getThirdExpr()).getName()) ) + if( getThirdExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getThirdExpr()).getName()) && !conditional ) _args[2] = constVars.get(((DataIdentifier)getThirdExpr()).getName()); - if( _args[3] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[3]).getName()) ) + if( _args[3] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[3]).getName()) && !conditional ) _args[3] = constVars.get(((DataIdentifier)_args[3]).getName()); if ( getThirdExpr().getOutput() instanceof ConstIdentifier ) @@ -903,9 +901,9 @@ public class BuiltinFunctionExpression extends DataIdentifier } else { // constant propagation - if( _args[3] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[3]).getName()) ) + if( _args[3] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[3]).getName()) && !conditional ) _args[3] = constVars.get(((DataIdentifier)_args[3]).getName()); - if( _args[4] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[4]).getName()) ) + if( _args[4] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[4]).getName()) && !conditional ) _args[4] = constVars.get(((DataIdentifier)_args[4]).getName()); if ( _args[3].getOutput() instanceof ConstIdentifier ) @@ -1164,12 +1162,14 @@ public class BuiltinFunctionExpression extends 
DataIdentifier checkNumParameters(2); // constant propagation (from, to, incr) - if( getFirstExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getFirstExpr()).getName()) ) - _args[0] = constVars.get(((DataIdentifier)getFirstExpr()).getName()); - if( getSecondExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getSecondExpr()).getName()) ) - _args[1] = constVars.get(((DataIdentifier)getSecondExpr()).getName()); - if( getThirdExpr()!=null && getThirdExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getThirdExpr()).getName()) ) - _args[2] = constVars.get(((DataIdentifier)getThirdExpr()).getName()); + if( !conditional ) { + if( getFirstExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getFirstExpr()).getName()) ) + _args[0] = constVars.get(((DataIdentifier)getFirstExpr()).getName()); + if( getSecondExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getSecondExpr()).getName()) ) + _args[1] = constVars.get(((DataIdentifier)getSecondExpr()).getName()); + if( getThirdExpr()!=null && getThirdExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getThirdExpr()).getName()) ) + _args[2] = constVars.get(((DataIdentifier)getThirdExpr()).getName()); + } // check if dimensions can be inferred long dim1=-1, dim2=1; http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index fd0b47b..47c28a2 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -2352,7 +2352,6 @@ public class DMLTranslator currBuiltinOp = (expr.getDim1()==-1) ? 
new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NROW, expr) : new LiteralOp(expr.getDim1()); break; - case NCOL: // If the dimensions are available at compile time, then create a LiteralOp (constant propagation) // Else create a UnaryOp so that a control program instruction is generated @@ -2365,7 +2364,7 @@ public class DMLTranslator currBuiltinOp = (expr.getDim1()==-1 || expr.getDim2()==-1) ? new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.LENGTH, expr) : new LiteralOp(expr.getDim1()*expr.getDim2()); break; - + case LIST: currBuiltinOp = new NaryOp(target.getName(), DataType.LIST, ValueType.UNKNOWN, OpOpN.LIST, processAllExpressions(source.getAllExpr(), hops)); http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/main/java/org/apache/sysml/parser/DataExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DataExpression.java b/src/main/java/org/apache/sysml/parser/DataExpression.java index 7ce65a8..7218023 100644 --- a/src/main/java/org/apache/sysml/parser/DataExpression.java +++ b/src/main/java/org/apache/sysml/parser/DataExpression.java @@ -535,8 +535,10 @@ public class DataExpression extends DataIdentifier } //general data expression constant propagation - performConstantPropagationRand( currConstVars ); - performConstantPropagationReadWrite( currConstVars ); + if( !conditional ) { + performConstantPropagationRand( currConstVars ); + performConstantPropagationReadWrite( currConstVars ); + } // check if data parameter of matrix is scalar or matrix -- if scalar, use Rand instead Expression dataParam1 = getVarParam(RAND_DATA); @@ -1591,7 +1593,7 @@ public class DataExpression extends DataIdentifier && currConstVars.containsKey(((DataIdentifier) paramExp).getName())) { addVarParam(paramName, currConstVars.get(((DataIdentifier)paramExp).getName())); - } + } } } 
http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/main/java/org/apache/sysml/parser/ForStatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ForStatementBlock.java b/src/main/java/org/apache/sysml/parser/ForStatementBlock.java index a29f48e..dd6f7a4 100644 --- a/src/main/java/org/apache/sysml/parser/ForStatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/ForStatementBlock.java @@ -68,7 +68,7 @@ public class ForStatementBlock extends StatementBlock DataIdentifier copyId = new DataIdentifier(origId); origVarsBeforeBody.addVariable(key, copyId); } - + ////////////////////////////////////////////////////////////////////////////// // FIRST PASS: process the predicate / statement blocks in the body of the for statement /////////////////////////////////////////////////////////////////////////////// @@ -83,12 +83,12 @@ public class ForStatementBlock extends StatementBlock //perform constant propagation for ( from, to, incr ) //(e.g., useful for reducing false positives in parfor dependency analysis) - performConstantPropagation(constVars); + if( !conditional ) + performConstantPropagation(constVars); //validate body _dmlProg = dmlProg; - for(StatementBlock sb : body) - { + for(StatementBlock sb : body) { ids = sb.validate(dmlProg, ids, constVars, true); constVars = sb.getConstOut(); } @@ -101,7 +101,7 @@ public class ForStatementBlock extends StatementBlock // for each updated variable boolean revalidationRequired = false; for (String key : _updated.getVariableNames()) - { + { DataIdentifier startVersion = origVarsBeforeBody.getVariable(key); DataIdentifier endVersion = ids.getVariable(key); @@ -110,16 +110,14 @@ public class ForStatementBlock extends StatementBlock //handle data type change (reject) if (!startVersion.getOutput().getDataType().equals(endVersion.getOutput().getDataType())){ raiseValidateError("ForStatementBlock has unsupported conditional data 
type change of variable '"+key+"' in loop body.", conditional); - } + } //handle size change - long startVersionDim1 = (startVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)startVersion).getOrigDim1() : startVersion.getDim1(); - long endVersionDim1 = (endVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)endVersion).getOrigDim1() : endVersion.getDim1(); - long startVersionDim2 = (startVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)startVersion).getOrigDim2() : startVersion.getDim2(); - long endVersionDim2 = (endVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)endVersion).getOrigDim2() : endVersion.getDim2(); - - boolean sizeUnchanged = ((startVersionDim1 == endVersionDim1) && - (startVersionDim2 == endVersionDim2) ); + long startVersionDim1 = (startVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)startVersion).getOrigDim1() : startVersion.getDim1(); + long endVersionDim1 = (endVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)endVersion).getOrigDim1() : endVersion.getDim1(); + long startVersionDim2 = (startVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)startVersion).getOrigDim2() : startVersion.getDim2(); + long endVersionDim2 = (endVersion instanceof IndexedIdentifier) ? ((IndexedIdentifier)endVersion).getOrigDim2() : endVersion.getDim2(); + boolean sizeUnchanged = ((startVersionDim1 == endVersionDim1) && (startVersionDim2 == endVersionDim2) ); //handle sparsity change //NOTE: nnz not propagated via validate, and hence, we conservatively assume that nnz have been changed. 
@@ -155,18 +153,18 @@ public class ForStatementBlock extends StatementBlock for( String var : _updated.getVariableNames() ) if( constVars.containsKey( var ) ) constVars.remove( var ); - + //perform constant propagation for ( from, to, incr ) //(e.g., useful for reducing false positives in parfor dependency analysis) - performConstantPropagation(constVars); + if( !conditional ) + performConstantPropagation(constVars); predicate.validateExpression(ids.getVariables(), constVars, conditional); body = fs.getBody(); //validate body _dmlProg = dmlProg; - for(StatementBlock sb : body) - { + for(StatementBlock sb : body) { ids = sb.validate(dmlProg, ids, constVars, true); constVars = sb.getConstOut(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/main/java/org/apache/sysml/parser/FunctionCallIdentifier.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/FunctionCallIdentifier.java b/src/main/java/org/apache/sysml/parser/FunctionCallIdentifier.java index 2f2e618..4ffe582 100644 --- a/src/main/java/org/apache/sysml/parser/FunctionCallIdentifier.java +++ b/src/main/java/org/apache/sysml/parser/FunctionCallIdentifier.java @@ -129,13 +129,15 @@ public class FunctionCallIdentifier extends DataIdentifier } // Step 5: constant propagation into function call statement - for( ParameterExpression paramExpr : _paramExprs ) { - Expression expri = paramExpr.getExpr(); - if( expri instanceof DataIdentifier && !(expri instanceof IndexedIdentifier) - && constVars.containsKey(((DataIdentifier)expri).getName()) ) - { - //replace varname with constant in function call expression - paramExpr.setExpr(constVars.get(((DataIdentifier)expri).getName())); + if( !conditional ) { + for( ParameterExpression paramExpr : _paramExprs ) { + Expression expri = paramExpr.getExpr(); + if( expri instanceof DataIdentifier && !(expri instanceof IndexedIdentifier) + && 
constVars.containsKey(((DataIdentifier)expri).getName()) ) + { + //replace varname with constant in function call expression + paramExpr.setExpr(constVars.get(((DataIdentifier)expri).getName())); + } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/main/java/org/apache/sysml/parser/IfStatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/IfStatementBlock.java b/src/main/java/org/apache/sysml/parser/IfStatementBlock.java index de4e8a7..74b496d 100644 --- a/src/main/java/org/apache/sysml/parser/IfStatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/IfStatementBlock.java @@ -48,7 +48,7 @@ public class IfStatementBlock extends StatementBlock //validate conditional predicate (incl constant propagation) Expression pred = ifstmt.getConditionalPredicate().getPredicate(); pred.validateExpression(ids.getVariables(), constVars, conditional); - if( pred instanceof DataIdentifier && constVars.containsKey( ((DataIdentifier)pred).getName()) ) { + if( pred instanceof DataIdentifier && constVars.containsKey( ((DataIdentifier)pred).getName()) && !conditional ) { ifstmt.getConditionalPredicate().setPredicate(constVars.get(((DataIdentifier)pred).getName())); } http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/main/java/org/apache/sysml/parser/IndexedIdentifier.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/IndexedIdentifier.java b/src/main/java/org/apache/sysml/parser/IndexedIdentifier.java index 97e94c5..a94ca94 100644 --- a/src/main/java/org/apache/sysml/parser/IndexedIdentifier.java +++ b/src/main/java/org/apache/sysml/parser/IndexedIdentifier.java @@ -117,7 +117,7 @@ public class IndexedIdentifier extends DataIdentifier String identifierName = ((DataIdentifier)_rowLowerBound).getName(); // CASE: rowLowerBound is a constant DataIdentifier - if 
(currConstVars.containsKey(identifierName)){ + if (currConstVars.containsKey(identifierName) && !conditional){ ConstIdentifier constValue = currConstVars.get(identifierName); if (!(constValue instanceof IntIdentifier || constValue instanceof DoubleIdentifier )) @@ -200,7 +200,7 @@ public class IndexedIdentifier extends DataIdentifier else if (_rowUpperBound != null && _rowUpperBound instanceof DataIdentifier && !(_rowUpperBound instanceof IndexedIdentifier)) { String identifierName = ((DataIdentifier)_rowUpperBound).getName(); - if (currConstVars.containsKey(identifierName)){ + if (currConstVars.containsKey(identifierName) && !conditional){ ConstIdentifier constValue = currConstVars.get(identifierName); if (!(constValue instanceof IntIdentifier || constValue instanceof DoubleIdentifier )) @@ -275,7 +275,7 @@ public class IndexedIdentifier extends DataIdentifier // perform constant propogation for column lower bound else if (_colLowerBound != null && _colLowerBound instanceof DataIdentifier && !(_colLowerBound instanceof IndexedIdentifier)) { String identifierName = ((DataIdentifier)_colLowerBound).getName(); - if (currConstVars.containsKey(identifierName)){ + if (currConstVars.containsKey(identifierName) && !conditional){ ConstIdentifier constValue = currConstVars.get(identifierName); if (!(constValue instanceof IntIdentifier || constValue instanceof DoubleIdentifier )) @@ -362,7 +362,7 @@ public class IndexedIdentifier extends DataIdentifier String identifierName = ((DataIdentifier)_colUpperBound).getName(); - if (currConstVars.containsKey(identifierName)){ + if (currConstVars.containsKey(identifierName) && !conditional){ ConstIdentifier constValue = currConstVars.get(identifierName); if (!(constValue instanceof IntIdentifier || constValue instanceof DoubleIdentifier )) http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/main/java/org/apache/sysml/parser/RelationalExpression.java ---------------------------------------------------------------------- 
diff --git a/src/main/java/org/apache/sysml/parser/RelationalExpression.java b/src/main/java/org/apache/sysml/parser/RelationalExpression.java index ca074cf..eed568c 100644 --- a/src/main/java/org/apache/sysml/parser/RelationalExpression.java +++ b/src/main/java/org/apache/sysml/parser/RelationalExpression.java @@ -126,10 +126,12 @@ public class RelationalExpression extends Expression _right.validateExpression(ids, constVars, conditional); //constant propagation (precondition for more complex constant folding rewrite) - if( _left instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _left).getName()) ) - _left = constVars.get(((DataIdentifier) _left).getName()); - if( _right instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _right).getName()) ) - _right = constVars.get(((DataIdentifier) _right).getName()); + if( !conditional ) { + if( _left instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _left).getName()) ) + _left = constVars.get(((DataIdentifier) _left).getName()); + if( _right instanceof DataIdentifier && constVars.containsKey(((DataIdentifier) _right).getName()) ) + _right = constVars.get(((DataIdentifier) _right).getName()); + } String outputName = getTempName(); DataIdentifier output = new DataIdentifier(outputName); http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/test/java/org/apache/sysml/test/integration/functions/misc/SizePropagationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/SizePropagationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/SizePropagationTest.java index a2714f0..2a526d3 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/SizePropagationTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/SizePropagationTest.java @@ -39,6 +39,7 @@ public class SizePropagationTest extends 
AutomatedTestBase private static final String TEST_NAME2 = "SizePropagationLoopIx1"; private static final String TEST_NAME3 = "SizePropagationLoopIx2"; private static final String TEST_NAME4 = "SizePropagationLoopIx3"; + private static final String TEST_NAME5 = "SizePropagationLoopIx4"; private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + SizePropagationTest.class.getSimpleName() + "/"; @@ -52,6 +53,7 @@ public class SizePropagationTest extends AutomatedTestBase addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) ); addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) ); } @Test @@ -94,6 +96,16 @@ public class SizePropagationTest extends AutomatedTestBase testSizePropagation( TEST_NAME4, true, N-1 ); } + @Test + public void testSizePropagationLoopIx4NoRewrites() { + testSizePropagation( TEST_NAME5, false, N ); + } + + @Test + public void testSizePropagationLoopIx4Rewrites() { + testSizePropagation( TEST_NAME5, true, N ); + } + private void testSizePropagation( String testname, boolean rewrites, int expect ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; RUNTIME_PLATFORM oldPlatform = rtplatform; http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/test/java/org/apache/sysml/test/integration/functions/recompile/RandSizeExpressionEvalTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/recompile/RandSizeExpressionEvalTest.java b/src/test/java/org/apache/sysml/test/integration/functions/recompile/RandSizeExpressionEvalTest.java index cb0299c..5708b75 100644 
--- a/src/test/java/org/apache/sysml/test/integration/functions/recompile/RandSizeExpressionEvalTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/recompile/RandSizeExpressionEvalTest.java @@ -42,38 +42,33 @@ public class RandSizeExpressionEvalTest extends AutomatedTestBase private final static int cols = 14; @Override - public void setUp() - { + public void setUp() { addTestConfiguration( TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{} )); } @Test - public void testComplexRand() - { + public void testComplexRand() { runRandTest(TEST_NAME, false, false); } @Test - public void testComplexRandExprEval() - { + public void testComplexRandExprEval() { runRandTest(TEST_NAME, true, false); } @Test - public void testComplexRandConstFold() - { + public void testComplexRandConstFold() { runRandTest(TEST_NAME, false, true); } private void runRandTest( String testName, boolean evalExpr, boolean constFold ) - { + { boolean oldFlagEval = OptimizerUtils.ALLOW_SIZE_EXPRESSION_EVALUATION; boolean oldFlagFold = OptimizerUtils.ALLOW_CONSTANT_FOLDING; - boolean oldFlagRand1 = OptimizerUtils.ALLOW_RAND_JOB_RECOMPILE; boolean oldFlagRand2 = OptimizerUtils.ALLOW_BRANCH_REMOVAL; boolean oldFlagRand3 = OptimizerUtils.ALLOW_WORSTCASE_SIZE_EXPRESSION_EVALUATION; - + try { TestConfiguration config = getTestConfiguration(testName); @@ -81,10 +76,10 @@ public class RandSizeExpressionEvalTest extends AutomatedTestBase config.addVariable("cols", cols); loadTestConfiguration(config); - /* This is for running the junit test the new way, i.e., construct the arguments directly */ String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testName + ".dml"; - programArgs = new String[]{"-args", Integer.toString(rows), Integer.toString(cols), output("R") }; + programArgs = new String[]{"-explain", "-args", + Integer.toString(rows), Integer.toString(cols), output("R") }; OptimizerUtils.ALLOW_SIZE_EXPRESSION_EVALUATION = evalExpr; 
OptimizerUtils.ALLOW_CONSTANT_FOLDING = constFold; @@ -93,35 +88,29 @@ public class RandSizeExpressionEvalTest extends AutomatedTestBase OptimizerUtils.ALLOW_RAND_JOB_RECOMPILE = false; OptimizerUtils.ALLOW_BRANCH_REMOVAL = false; OptimizerUtils.ALLOW_WORSTCASE_SIZE_EXPRESSION_EVALUATION = false; - - boolean exceptionExpected = false; - runTest(true, exceptionExpected, null, -1); + + runTest(true, false, null, -1); //check correct propagated size via final results HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); Assert.assertEquals("Unexpected results.", Double.valueOf(rows*cols*3.0), dmlfile.get(new CellIndex(1,1))); //check expected number of compiled and executed MR jobs - if( evalExpr || constFold ) - { - Assert.assertEquals("Unexpected number of executed MR jobs.", - 0, Statistics.getNoOfExecutedMRJobs()); + if( evalExpr || constFold ) { + Assert.assertEquals("Unexpected number of executed MR jobs.", + 0, Statistics.getNoOfExecutedMRJobs()); + } + else { + Assert.assertEquals("Unexpected number of executed MR jobs.", + 2, Statistics.getNoOfExecutedMRJobs()); //Rand, GMR (sum) } - else - { - Assert.assertEquals("Unexpected number of executed MR jobs.", - 2, Statistics.getNoOfExecutedMRJobs()); //Rand, GMR (sum) - } } - finally - { + finally { OptimizerUtils.ALLOW_SIZE_EXPRESSION_EVALUATION = oldFlagEval; OptimizerUtils.ALLOW_CONSTANT_FOLDING = oldFlagFold; - OptimizerUtils.ALLOW_RAND_JOB_RECOMPILE = oldFlagRand1; OptimizerUtils.ALLOW_BRANCH_REMOVAL = oldFlagRand2; OptimizerUtils.ALLOW_WORSTCASE_SIZE_EXPRESSION_EVALUATION = oldFlagRand3; } } - -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/systemml/blob/9861d7a3/src/test/scripts/functions/misc/SizePropagationLoopIx4.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/SizePropagationLoopIx4.dml b/src/test/scripts/functions/misc/SizePropagationLoopIx4.dml new file mode 100644 index 0000000..a9d773e --- 
/dev/null +++ b/src/test/scripts/functions/misc/SizePropagationLoopIx4.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rand(rows=$1, cols=1); +Y = X +# no loop iterations: n=$1 +for( i in seq(1,0,1) ) { + n1 = nrow(Y); #assign to var + Y = Y[2:n1,] - Y[1:n1-1,]; + #Y = Y[2:nrow(Y),] - Y[1:nrow(Y)-1,]; +} +n = nrow(Y); +R = as.matrix(n); +write(R, $2);
