Repository: systemml
Updated Branches:
  refs/heads/master 244f8049d -> d71a7d359


[SYSTEMML-1725] Extended codegen row-wise template (full rc aggregates)

This patch extends the code generator row-wise template with support for
full row/column aggregations (in addition to the types no_agg, row_agg,
and col_agg). The motivation is scenarios where we currently
compile a sequence of row-wise and cell-wise templates because the
patterns contain both intermediate row aggregates and a final full
aggregation.

For example, on Mlogreg over a 100Mx10 input, this patch improved
end-to-end performance with codegen from 248s (w/ 59/29/1
bufferpool writes) to 193s (w/ 45/15/1 bufferpool writes). In
comparison, without codegen but with existing fused operators, the
end-to-end performance was 512s (w/ 234/30/1 writes).

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d71a7d35
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d71a7d35
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d71a7d35

Branch: refs/heads/master
Commit: d71a7d359ba1d776feff8447e02646a2623a41af
Parents: 244f804
Author: Matthias Boehm <[email protected]>
Authored: Tue Jun 20 12:59:47 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Tue Jun 20 13:00:00 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeRow.java      | 10 ++++---
 .../hops/codegen/template/TemplateRow.java      | 12 ++++++--
 .../hops/codegen/template/TemplateUtils.java    |  3 ++
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  5 +---
 .../sysml/runtime/codegen/SpoofRowwise.java     | 20 +++++++++++--
 .../instructions/spark/SpoofSPInstruction.java  | 16 ++++++----
 .../functions/codegen/RowAggTmplTest.java       | 18 +++++++++++-
 .../scripts/functions/codegen/rowAggPattern21.R | 31 ++++++++++++++++++++
 .../functions/codegen/rowAggPattern21.dml       | 26 ++++++++++++++++
 9 files changed, 122 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d71a7d35/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
index b2b16f2..546bf60 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
@@ -48,8 +48,9 @@ public class CNodeRow extends CNodeTpl
                        + "  }\n"                       
                        + "}\n";
 
-       private static final String TEMPLATE_ROWAGG_OUT = "    c[rowIndex] = 
%IN%;\n";
-       private static final String TEMPLATE_NOAGG_OUT = "    
LibSpoofPrimitives.vectWrite(%IN%, c, rowIndex*len, len);\n";
+       private static final String TEMPLATE_ROWAGG_OUT  = "    c[rowIndex] = 
%IN%;\n";
+       private static final String TEMPLATE_FULLAGG_OUT = "    c[0] += 
%IN%;\n";
+       private static final String TEMPLATE_NOAGG_OUT   = "    
LibSpoofPrimitives.vectWrite(%IN%, c, rowIndex*len, len);\n";
        
        public CNodeRow(ArrayList<CNode> inputs, CNode output ) {
                super(inputs, output);
@@ -114,8 +115,8 @@ public class CNodeRow extends CNodeTpl
        
        private String getOutputStatement(String varName) {
                if( !_type.isColumnAgg() ) {
-                       String tmp = (_type==RowType.NO_AGG) ?
-                               TEMPLATE_NOAGG_OUT : TEMPLATE_ROWAGG_OUT;
+                       String tmp = (_type==RowType.NO_AGG) ? 
TEMPLATE_NOAGG_OUT : 
+                               (_type==RowType.FULL_AGG) ? 
TEMPLATE_FULLAGG_OUT : TEMPLATE_ROWAGG_OUT;
                        return tmp.replace("%IN%", varName);
                }
                return "";
@@ -131,6 +132,7 @@ public class CNodeRow extends CNodeTpl
        public SpoofOutputDimsType getOutputDimType() {
                switch( _type ) {
                        case NO_AGG: return SpoofOutputDimsType.INPUT_DIMS;
+                       case FULL_AGG: return SpoofOutputDimsType.SCALAR;
                        case ROW_AGG: return TemplateUtils.isUnary(_output, 
UnaryType.CBIND0) ?
                                SpoofOutputDimsType.ROW_DIMS2 : 
SpoofOutputDimsType.ROW_DIMS;
                        case COL_AGG: return 
SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector

http://git-wip-us.apache.org/repos/asf/systemml/blob/d71a7d35/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index ca10569..71091cf 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -96,6 +96,8 @@ public class TemplateRow extends TemplateBase
                                        && TemplateCell.isValidOperation(hop))  
        
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol
                                && HopRewriteUtils.isAggUnaryOp(hop, 
SUPPORTED_ROW_AGG))
+                       || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection() == Direction.RowCol 
+                               && ((AggUnaryOp)hop).getOp() == AggOp.SUM )
                        || (hop instanceof AggBinaryOp && hop.getDim1()>1 && 
hop.getDim2()==1
                                && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
        }
@@ -113,8 +115,8 @@ public class TemplateRow extends TemplateBase
 
        @Override
        public CloseType close(Hop hop) {
-               //close on column aggregate (e.g., colSums, t(X)%*%y)
-               if(    (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()==Direction.Col)
+               //close on column or full aggregate (e.g., colSums, t(X)%*%y)
+               if(    (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.Row)
                        || (hop instanceof AggBinaryOp && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))
                        || HopRewriteUtils.isBinary(hop, OpOp2.CBIND) )
                        return CloseType.CLOSED_VALID;
@@ -188,7 +190,7 @@ public class TemplateRow extends TemplateBase
                                                inHops2.put("X", 
hop.getInput().get(0));
                                }
                        }
-                       else  if (((AggUnaryOp)hop).getDirection() == 
Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
+                       else if (((AggUnaryOp)hop).getDirection() == 
Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
                                //vector add without temporary copy
                                if( cdata1 instanceof CNodeBinary && 
((CNodeBinary)cdata1).getType().isVectorScalarPrimitive() )
                                        out = new 
CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), 
@@ -196,6 +198,10 @@ public class TemplateRow extends TemplateBase
                                else    
                                        out = cdata1;
                        }
+                       else if( ((AggUnaryOp)hop).getDirection() == 
Direction.RowCol && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
+                               out = (cdata1.getDataType().isMatrix()) ?
+                                       new CNodeUnary(cdata1, 
UnaryType.ROW_SUMS) : cdata1;
+                       }
                }
                else if(hop instanceof AggBinaryOp)
                {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d71a7d35/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index fbaaab6..6111e9d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -250,6 +250,9 @@ public class TemplateUtils
                        && !(output instanceof AggBinaryOp && HopRewriteUtils
                                
.isTransposeOfItself(output.getInput().get(0),input)))
                        return RowType.ROW_AGG;
+               else if( output instanceof AggUnaryOp 
+                       && 
((AggUnaryOp)output).getDirection()==Direction.RowCol )
+                       return RowType.FULL_AGG;
                else if( output.getDim1()==input.getDim2() && 
output.getDim2()==1 )
                        return RowType.COL_AGG_T;
                else

http://git-wip-us.apache.org/repos/asf/systemml/blob/d71a7d35/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index b98901a..582809e 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -854,10 +854,7 @@ public class HopRewriteUtils
                if( !(hop instanceof AggUnaryOp) )
                        return false;
                AggOp hopOp = ((AggUnaryOp)hop).getOp();
-               for( AggOp opi : op ) 
-                       if( hopOp == opi )
-                               return true;
-               return false; 
+               return ArrayUtils.contains(op, hopOp);
        }
        
        public static boolean isSum(Hop hop) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d71a7d35/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java 
b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
index 529b838..104a1bf 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
@@ -29,6 +29,7 @@ import java.util.concurrent.Future;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysml.runtime.instructions.cp.DoubleObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
 import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -45,6 +46,7 @@ public abstract class SpoofRowwise extends SpoofOperator
        
        public enum RowType {
                NO_AGG,    //no aggregation
+               FULL_AGG,  //full row/col aggregation
                ROW_AGG,   //row aggregation (e.g., rowSums() or X %*% v)
                COL_AGG,   //col aggregation (e.g., colSums() or t(y) %*% X)
                COL_AGG_T; //transposed col aggregation (e.g., t(X) %*% y)
@@ -82,6 +84,18 @@ public abstract class SpoofRowwise extends SpoofOperator
        }
        
        @Override
+       public ScalarObject execute(ArrayList<MatrixBlock> inputs, 
ArrayList<ScalarObject> scalarObjects, int k) 
+               throws DMLRuntimeException 
+       {
+               MatrixBlock out = new MatrixBlock(1, 1, false);
+               if( k > 1 )
+                       execute(inputs, scalarObjects, out, k);
+               else
+                       execute(inputs, scalarObjects, out);
+               return new DoubleObject(out.quickGetValue(0, 0));
+       }
+       
+       @Override
        public void execute(ArrayList<MatrixBlock> inputs, 
ArrayList<ScalarObject> scalarObjects, MatrixBlock out)      
                throws DMLRuntimeException 
        {
@@ -155,15 +169,16 @@ public abstract class SpoofRowwise extends SpoofOperator
                int blklen = (int)(Math.ceil((double)m/nk));
                try
                {
-                       if( _type.isColumnAgg() ) {
+                       if( _type.isColumnAgg() || _type == RowType.FULL_AGG ) {
                                //execute tasks
                                ArrayList<ParColAggTask> tasks = new 
ArrayList<ParColAggTask>();
                                for( int i=0; i<nk & i*blklen<m; i++ )
                                        tasks.add(new 
ParColAggTask(inputs.get(0), b, scalars, n, i*blklen, Math.min((i+1)*blklen, 
m)));
                                List<Future<double[]>> taskret = 
pool.invokeAll(tasks); 
                                //aggregate partial results
+                               int len = _type.isColumnAgg() ? n : 1;
                                for( Future<double[]> task : taskret )
-                                       LibMatrixMult.vectAdd(task.get(), 
out.getDenseBlock(), 0, 0, n);
+                                       LibMatrixMult.vectAdd(task.get(), 
out.getDenseBlock(), 0, 0, len);
                                out.recomputeNonZeros();
                        }
                        else {
@@ -190,6 +205,7 @@ public abstract class SpoofRowwise extends SpoofOperator
        private void allocateOutputMatrix(int m, int n, MatrixBlock out) {
                switch( _type ) {
                        case NO_AGG: out.reset(m, n, false); break;
+                       case FULL_AGG: out.reset(1, 1, false); break;
                        case ROW_AGG: out.reset(m, 1+(_cbind0?1:0), false); 
break;
                        case COL_AGG: out.reset(1, n, false); break;
                        case COL_AGG_T: out.reset(n, 1, false); break;

http://git-wip-us.apache.org/repos/asf/systemml/blob/d71a7d35/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
index f11b3d0..622944d 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
@@ -200,11 +200,16 @@ public class SpoofSPInstruction extends SPInstruction
                else if( _class.getSuperclass() == SpoofRowwise.class ) { //row 
aggregate operator
                        SpoofRowwise op = (SpoofRowwise) 
CodegenUtils.createInstance(_class);   
                        RowwiseFunction fmmc = new 
RowwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars, 
(int)mcIn.getCols());
-                       out = in.mapPartitionsToPair(fmmc, 
op.getRowType()==RowType.ROW_AGG);
+                       out = in.mapPartitionsToPair(fmmc, 
op.getRowType()==RowType.ROW_AGG
+                                       || op.getRowType() == RowType.NO_AGG);
                        
-                       if( op.getRowType().isColumnAgg() ) {
-                               MatrixBlock tmpMB = 
RDDAggregateUtils.sumStable(out);           
-                               sec.setMatrixOutput(_out.getName(), tmpMB);
+                       if( op.getRowType().isColumnAgg() || 
op.getRowType()==RowType.FULL_AGG ) {
+                               MatrixBlock tmpMB = 
RDDAggregateUtils.sumStable(out);
+                               if( op.getRowType().isColumnAgg() )
+                                       sec.setMatrixOutput(_out.getName(), 
tmpMB);
+                               else
+                                       sec.setScalarOutput(_out.getName(), 
+                                               new 
DoubleObject(tmpMB.quickGetValue(0, 0)));
                        }
                        else //row-agg or no-agg 
                        {
@@ -311,7 +316,8 @@ public class SpoofSPInstruction extends SPInstruction
                        
LibSpoofPrimitives.setupThreadLocalMemory(_op.getNumIntermediates(), _clen);
                        
                        ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new 
ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();
-                       boolean aggIncr = _op.getRowType().isColumnAgg(); 
//aggregate entire partition to avoid allocations
+                       boolean aggIncr = (_op.getRowType().isColumnAgg() 
//aggregate entire partition
+                               || _op.getRowType() == RowType.FULL_AGG); 
                        MatrixBlock blkOut = aggIncr ? new MatrixBlock() : null;
                        
                        while( arg.hasNext() ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d71a7d35/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index 809b812..f867e18 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -56,6 +56,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        private static final String TEST_NAME18 = TEST_NAME+"18"; //MLogreg - 
matrix-vector cbind 0s
        private static final String TEST_NAME19 = TEST_NAME+"19"; //MLogreg - 
rowwise dag
        private static final String TEST_NAME20 = TEST_NAME+"20"; //1 / (1 - (A 
/ rowSums(A)))
+       private static final String TEST_NAME21 = TEST_NAME+"21"; 
//sum(X/rowSums(X))
        
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RowAggTmplTest.class.getSimpleName() + "/";
@@ -67,7 +68,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for(int i=1; i<=20; i++)
+               for(int i=1; i<=21; i++)
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) 
}) );
        }
        
@@ -371,6 +372,21 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME20, false, ExecType.SPARK );
        }
        
+       @Test   
+       public void testCodegenRowAggRewrite21CP() {
+               testCodegenIntegration( TEST_NAME21, true, ExecType.CP );
+       }
+       
+       @Test
+       public void testCodegenRowAgg21CP() {
+               testCodegenIntegration( TEST_NAME21, false, ExecType.CP );
+       }
+       
+       @Test
+       public void testCodegenRowAgg21SP() {
+               testCodegenIntegration( TEST_NAME21, false, ExecType.SPARK );
+       }
+       
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {       
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;

http://git-wip-us.apache.org/repos/asf/systemml/blob/d71a7d35/src/test/scripts/functions/codegen/rowAggPattern21.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern21.R 
b/src/test/scripts/functions/codegen/rowAggPattern21.R
new file mode 100644
index 0000000..fda9934
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern21.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+A = matrix(1, 1500, 7);
+  
+R = as.matrix(sum(A / rowSums(A)));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/d71a7d35/src/test/scripts/functions/codegen/rowAggPattern21.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern21.dml 
b/src/test/scripts/functions/codegen/rowAggPattern21.dml
new file mode 100644
index 0000000..b671f88
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern21.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = matrix(1, 1500, 7);
+
+R = as.matrix(sum(A / rowSums(A)));
+
+write(R, $1)

Reply via email to