Repository: systemml
Updated Branches:
  refs/heads/master cad7c1e0f -> cba082eb1


[SYSTEMML-2134] Codegen support for vector ternary axpy in row tmpls

This patch adds the missing codegen support for ternary axpy vector
operations in row templates. Previously, we only supported the builtin
fused axpy operator as a cell operation in cell/row templates, which led
to missed fusion opportunities and invalid fusion plans in special cases.
While the cell template remains unchanged, row templates now simply
compile two binary operations if row vector intermediates are necessary.
This avoids adding dedicated library functions for all combinations of
sparse and dense vectors in these ternary operations.

For example, on PageRank with rewrites enabled, this patch now correctly
compiles the main matrix-vector multiply into the fused operator.
Similarly, for MDABivar, it fixes runtime exceptions with rewrites enabled.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/cba082eb
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/cba082eb
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/cba082eb

Branch: refs/heads/master
Commit: cba082eb1500fceb46cdd295c3438015c5dbb3a5
Parents: cad7c1e
Author: Matthias Boehm <[email protected]>
Authored: Thu May 31 22:34:21 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu May 31 22:34:21 2018 -0700

----------------------------------------------------------------------
 .../hops/codegen/template/TemplateRow.java      | 15 +++++++--
 .../functions/codegenalg/AlgorithmPageRank.java | 33 ++++++++++++++++++--
 2 files changed, 43 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/cba082eb/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 95be74a..ed68e75 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -53,6 +53,7 @@ import org.apache.sysml.hops.Hop.DataGenMethod;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.OpOpN;
 import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.parser.Statement;
@@ -89,6 +90,7 @@ public class TemplateRow extends TemplateBase
                        || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp)
                                && TemplateCell.isValidOperation(hop) && 
hop.getDim1() > 1)
                        || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+                       || HopRewriteUtils.isTernary(hop, OpOp3.PLUS_MULT, 
OpOp3.MINUS_MULT)
                        || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
                        || (hop instanceof AggBinaryOp && hop.dimsKnown() && 
hop.getDim2()==1 //MV
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
@@ -117,6 +119,7 @@ public class TemplateRow extends TemplateBase
                        || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
                        || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp) 
                                && TemplateCell.isValidOperation(hop))
+                       || HopRewriteUtils.isTernary(hop, OpOp3.PLUS_MULT, 
OpOp3.MINUS_MULT)
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol
                                && HopRewriteUtils.isAggUnaryOp(hop, 
SUPPORTED_ROW_AGG))
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection() == Direction.RowCol 
@@ -449,9 +452,15 @@ public class TemplateRow extends TemplateBase
                        cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, 
hop.getInput().get(0));
                        cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, 
hop.getInput().get(2));
                        
-                       //construct ternary cnode, primitive operation derived 
from OpOp3
-                       out = new CNodeTernary(cdata1, cdata2, cdata3, 
-                               TernaryType.valueOf(top.getOp().toString()));
+                       if( hop.getDim2() > 2 ) { //row vectors
+                               out = new CNodeBinary(cdata1, new 
CNodeBinary(cdata2, cdata3, BinType.VECT_MULT_SCALAR),
+                                       top.getOp()==OpOp3.PLUS_MULT? 
BinType.VECT_PLUS : BinType.VECT_MINUS);
+                       }
+                       else {
+                               //construct scalar ternary cnode, primitive 
operation derived from OpOp3 
+                               out = new CNodeTernary(cdata1, cdata2, cdata3, 
+                                       
TernaryType.valueOf(top.getOp().toString()));
+                       }
                }
                else if(HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
                        CNode[] inputs = new CNode[hop.getInput().size()];

http://git-wip-us.apache.org/repos/asf/systemml/blob/cba082eb/src/test/java/org/apache/sysml/test/integration/functions/codegenalg/AlgorithmPageRank.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegenalg/AlgorithmPageRank.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegenalg/AlgorithmPageRank.java
index 9299a77..5a429f4 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegenalg/AlgorithmPageRank.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegenalg/AlgorithmPageRank.java
@@ -96,6 +96,36 @@ public class AlgorithmPageRank extends AutomatedTestBase
        public void testPageRankSparseCPFuseNoRedundancy() {
                runPageRankTest(TEST_NAME1, true, true, ExecType.CP, 
TestType.FUSE_NO_REDUNDANCY);
        }
+       
+       @Test
+       public void testPageRankDenseCPNoR() {
+               runPageRankTest(TEST_NAME1, false, false, ExecType.CP, 
TestType.DEFAULT);
+       }
+       
+       @Test
+       public void testPageRankSparseCPNoR() {
+               runPageRankTest(TEST_NAME1, false, true, ExecType.CP, 
TestType.DEFAULT);
+       }
+
+       @Test
+       public void testPageRankDenseCPFuseAllNoR() {
+               runPageRankTest(TEST_NAME1, false, false, ExecType.CP, 
TestType.FUSE_ALL);
+       }
+
+       @Test
+       public void testPageRankSparseCPFuseAllNoR() {
+               runPageRankTest(TEST_NAME1, false, true, ExecType.CP, 
TestType.FUSE_ALL);
+       }
+
+       @Test
+       public void testPageRankDenseCPFuseNoRedundancyNoR() {
+               runPageRankTest(TEST_NAME1, false, false, ExecType.CP, 
TestType.FUSE_NO_REDUNDANCY);
+       }
+
+       @Test
+       public void testPageRankSparseCPFuseNoRedundancyNoR() {
+               runPageRankTest(TEST_NAME1, false, true, ExecType.CP, 
TestType.FUSE_NO_REDUNDANCY);
+       }
 
        private void runPageRankTest( String testname, boolean rewrites, 
boolean sparse, ExecType instType, TestType testType)
        {
@@ -125,8 +155,7 @@ public class AlgorithmPageRank extends AutomatedTestBase
                                String.valueOf(maxiter), expectedDir());
 
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
-                       //TODO test both with and without operator fusion
-                       OptimizerUtils.ALLOW_OPERATOR_FUSION = false;
+                       OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
                        
                        //generate actual datasets
                        double[][] G = getRandomMatrix(rows, cols, 1, 1, 
sparse?sparsity2:sparsity1, 234);

Reply via email to