Repository: systemml
Updated Branches:
  refs/heads/master ae4c00682 -> 1adfc7266


[SYSTEMML-1848] Performance codegen multi aggregates over sparse inputs

This patch improves the performance of codegen multi-aggregates by
determining sparse-safe multi-aggregates and exploiting sparse-safe
operations over sparse inputs during runtime. In a scenario
of 20 iterations of sum(X*Y) and sum(Y*Z), where all matrices are of
size 1M x 1K, with sparsity 0.1, this patch improved the runtime from
90.8s to 7.1s (compared to the baselines w/o and w/ fused operators of
129s and 12.4s, respectively).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/585afa20
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/585afa20
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/585afa20

Branch: refs/heads/master
Commit: 585afa205b21c73b6278f30d8b37681cb80fae88
Parents: ae4c006
Author: Matthias Boehm <[email protected]>
Authored: Thu Aug 17 18:06:00 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Aug 18 14:15:41 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeCell.java     |  2 +-
 .../sysml/hops/codegen/cplan/CNodeMultiAgg.java | 15 +++++-
 .../hops/codegen/template/TemplateCell.java     | 22 ++++++--
 .../hops/codegen/template/TemplateMultiAgg.java |  2 +
 .../hops/codegen/template/TemplateUtils.java    |  9 ++++
 .../sysml/runtime/codegen/SpoofCellwise.java    |  2 +-
 .../runtime/codegen/SpoofMultiAggregate.java    | 55 ++++++++++++++++----
 7 files changed, 89 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/585afa20/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
index dd3806d..25c422d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java
@@ -39,7 +39,7 @@ public class CNodeCell extends CNodeTpl
                        + "\n"
                        + "public final class %TMP% extends SpoofCellwise {\n" 
                        + "  public %TMP%() {\n"
-                       + "    super(CellType.%TYPE%, %AGG_OP%, 
%SPARSE_SAFE%);\n"
+                       + "    super(CellType.%TYPE%, %SPARSE_SAFE%, 
%AGG_OP%);\n"
                        + "  }\n"
                        + "  protected double genexec(double a, SideInput[] b, 
double[] scalars, int m, int n, int rowIndex, int colIndex) { \n"
                        + "%BODY_dense%"

http://git-wip-us.apache.org/repos/asf/systemml/blob/585afa20/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
index b6b3a80..8abf907 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
@@ -41,7 +41,7 @@ public class CNodeMultiAgg extends CNodeTpl
                        + "\n"
                        + "public final class %TMP% extends SpoofMultiAggregate 
{ \n"
                        + "  public %TMP%() {\n"
-                       + "    super(%AGG_OP%);\n"
+                       + "    super(%SPARSE_SAFE%, %AGG_OP%);\n"
                        + "  }\n"
                        + "  protected void genexec(double a, SideInput[] b, 
double[] scalars, double[] c, "
                                        + "int m, int n, int rowIndex, int 
colIndex) { \n"
@@ -56,6 +56,7 @@ public class CNodeMultiAgg extends CNodeTpl
        private ArrayList<CNode> _outputs = null; 
        private ArrayList<AggOp> _aggOps = null;
        private ArrayList<Hop> _roots = null;
+       private boolean _sparseSafe = false;
        
        public CNodeMultiAgg(ArrayList<CNode> inputs, ArrayList<CNode> outputs) 
{
                super(inputs, null);
@@ -89,6 +90,14 @@ public class CNodeMultiAgg extends CNodeTpl
                return _roots;
        }
        
+       public void setSparseSafe(boolean flag) {
+               _sparseSafe = flag;
+       }
+       
+       public boolean isSparseSafe() {
+               return _sparseSafe;
+       }
+       
        @Override
        public void renameInputs() {
                rRenameDataNode(_outputs, _inputs.get(0), "a"); // input matrix
@@ -130,7 +139,9 @@ public class CNodeMultiAgg extends CNodeTpl
                        aggList += "AggOp."+aggOp.name();
                }
                tmp = tmp.replace("%AGG_OP%", aggList);
-
+               tmp = tmp.replace("%SPARSE_SAFE%",
+                       String.valueOf(isSparseSafe()));
+               
                return tmp;
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/585afa20/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 65bad08..ad4589b 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -20,9 +20,11 @@
 package org.apache.sysml.hops.codegen.template;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
@@ -134,9 +136,8 @@ public class TemplateCell extends TemplateBase
                CNodeCell tpl = new CNodeCell(inputs, output);
                tpl.setCellType(TemplateUtils.getCellType(hop));
                tpl.setAggOp(TemplateUtils.getAggOp(hop));
-               tpl.setSparseSafe((HopRewriteUtils.isBinary(hop, OpOp2.MULT) && 
hop.getInput().contains(sinHops[0]))
-                               || (HopRewriteUtils.isBinary(hop, OpOp2.DIV) && 
hop.getInput().get(0) == sinHops[0])
-                               || TemplateUtils.rIsBinaryOnly(tpl.getOutput(), 
BinType.MULT));
+               tpl.setSparseSafe(isSparseSafe(Arrays.asList(hop), sinHops[0], 
+                       Arrays.asList(tpl.getOutput()), 
Arrays.asList(tpl.getAggOp()), false));
                tpl.setRequiresCastDtm(hop instanceof AggBinaryOp);
                tpl.setBeginLine(hop.getBeginLine());
                
@@ -315,6 +316,21 @@ public class TemplateCell extends TemplateBase
                                || (hop instanceof ParameterizedBuiltinOp && 
((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE));   
        }
        
+       protected boolean isSparseSafe(List<Hop> roots, Hop mainInput, 
List<CNode> outputs, List<AggOp> aggOps, boolean onlySum) {
+               boolean ret = true;
+               for( int i=0; i<outputs.size() && ret; i++ ) {
+                       ret &= (HopRewriteUtils.isBinary(roots.get(i), 
OpOp2.MULT) 
+                                       && 
roots.get(i).getInput().contains(mainInput))
+                               || (HopRewriteUtils.isBinary(roots.get(i), 
OpOp2.DIV) 
+                                       && roots.get(i).getInput().get(0) == 
mainInput)
+                               || (TemplateUtils.rIsBinaryOnly(outputs.get(i), 
BinType.MULT)
+                                       && 
TemplateUtils.rContainsInput(outputs.get(i), mainInput.getHopID()));
+                       if( onlySum )
+                               ret &= (aggOps.get(i)==AggOp.SUM || 
aggOps.get(i)==AggOp.SUM_SQ);
+               }
+               return ret;
+       }
+       
        /**
         * Comparator to order input hops of the cell template. We try to order 
         * matrices-vectors-scalars via sorting by number of cells and for 

http://git-wip-us.apache.org/repos/asf/systemml/blob/585afa20/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java
index bc51cf0..0c2886e 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java
@@ -107,6 +107,8 @@ public class TemplateMultiAgg extends TemplateCell
                }
                CNodeMultiAgg tpl = new CNodeMultiAgg(inputs, outputs);
                tpl.setAggOps(aggOps);
+               tpl.setSparseSafe(isSparseSafe(roots, sinHops[0], 
+                       tpl.getOutputs(), tpl.getAggOps(), true));
                tpl.setRootNodes(roots);
                tpl.setBeginLine(hop.getBeginLine());
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/585afa20/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 5b739ee..55f6fee 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -257,6 +257,15 @@ public class TemplateUtils
                return ret;
        }
        
+       public static boolean rContainsInput(CNode node, long hopID) {
+               boolean ret = false;
+               for( CNode c : node.getInput() )
+                       ret |= rContainsInput(c, hopID);
+               if( node instanceof CNodeData )
+                       ret |= (((CNodeData)node).getHopID()==hopID);
+               return ret;
+       }
+       
        public static boolean isTernary(CNode node, TernaryType...types) {
                return node instanceof CNodeTernary
                        && ArrayUtils.contains(types, 
((CNodeTernary)node).getType());

http://git-wip-us.apache.org/repos/asf/systemml/blob/585afa20/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java 
b/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java
index 63168e6..c35695f 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofCellwise.java
@@ -72,7 +72,7 @@ public abstract class SpoofCellwise extends SpoofOperator 
implements Serializabl
        private final AggOp _aggOp;
        private final boolean _sparseSafe;
        
-       public SpoofCellwise(CellType type, AggOp aggOp, boolean sparseSafe) {
+       public SpoofCellwise(CellType type, boolean sparseSafe, AggOp aggOp) {
                _type = type;
                _aggOp = aggOp;
                _sparseSafe = sparseSafe;

http://git-wip-us.apache.org/repos/asf/systemml/blob/585afa20/src/main/java/org/apache/sysml/runtime/codegen/SpoofMultiAggregate.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/codegen/SpoofMultiAggregate.java 
b/src/main/java/org/apache/sysml/runtime/codegen/SpoofMultiAggregate.java
index 679c964..43811f2 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofMultiAggregate.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofMultiAggregate.java
@@ -51,8 +51,10 @@ public abstract class SpoofMultiAggregate extends 
SpoofOperator implements Seria
        private static final long PAR_NUMCELL_THRESHOLD = 1024*1024;   //Min 1M 
elements
        
        private final AggOp[] _aggOps;
+       private final boolean _sparseSafe;
        
-       public SpoofMultiAggregate(AggOp... aggOps) {
+       public SpoofMultiAggregate(boolean sparseSafe, AggOp... aggOps) {
+               _sparseSafe = sparseSafe;
                _aggOps = aggOps;
        }
        
@@ -60,6 +62,10 @@ public abstract class SpoofMultiAggregate extends 
SpoofOperator implements Seria
                return _aggOps;
        }
        
+       public boolean isSparseSafe() {
+               return _sparseSafe;
+       }
+       
        @Override
        public String getSpoofType() {
                return "MA" +  getClass().getName().split("\\.")[1];
@@ -95,6 +101,7 @@ public abstract class SpoofMultiAggregate extends 
SpoofOperator implements Seria
                double[] scalars = prepInputScalars(scalarObjects);
                final int m = inputs.get(0).getNumRows();
                final int n = inputs.get(0).getNumColumns();
+               boolean sparseSafe = isSparseSafe();
                
                if( k <= 1 ) //SINGLE-THREADED
                {
@@ -103,7 +110,7 @@ public abstract class SpoofMultiAggregate extends 
SpoofOperator implements Seria
                        else if( !inputs.get(0).isInSparseFormat() )
                                executeDense(inputs.get(0).getDenseBlock(), b, 
scalars, c, m, n, 0, m);
                        else    
-                               executeSparse(inputs.get(0).getSparseBlock(), 
b, scalars, c, m, n, 0, m);
+                               executeSparse(inputs.get(0).getSparseBlock(), 
b, scalars, c, m, n, sparseSafe, 0, m);
                }
                else  //MULTI-THREADED
                {
@@ -113,7 +120,8 @@ public abstract class SpoofMultiAggregate extends 
SpoofOperator implements Seria
                                int nk = 
UtilFunctions.roundToNext(Math.min(8*k,m/32), k);
                                int blklen = (int)(Math.ceil((double)m/nk));
                                for( int i=0; i<nk & i*blklen<m; i++ )
-                                       tasks.add(new ParAggTask(inputs.get(0), 
b, scalars, m, n, i*blklen, Math.min((i+1)*blklen, m))); 
+                                       tasks.add(new ParAggTask(inputs.get(0), 
b, scalars,
+                                               m, n, sparseSafe, i*blklen, 
Math.min((i+1)*blklen, m))); 
                                //execute tasks
                                List<Future<double[]>> taskret = 
pool.invokeAll(tasks); 
                                pool.shutdown();
@@ -147,17 +155,40 @@ public abstract class SpoofMultiAggregate extends 
SpoofOperator implements Seria
                }
        }
        
-       private void executeSparse(SparseBlock sblock, SideInput[] b, double[] 
scalars, double[] c, int m, int n, int rl, int ru) 
+       private void executeSparse(SparseBlock sblock, SideInput[] b, double[] 
scalars,
+                       double[] c, int m, int n, boolean sparseSafe, int rl, 
int ru) 
                throws DMLRuntimeException 
        {
+               if( sblock == null && sparseSafe )
+                       return;
+               
                SideInput[] lb = createSparseSideInputs(b);
                
-               //core dense aggregation operation
-               for( int i=rl; i<ru; i++ )
-                       for( int j=0; j<n; j++ ) {
-                               double in = (sblock != null) ? sblock.get(i, j) 
: 0;
-                               genexec( in, lb, scalars, c, m, n, i, j );
+               //note: sequential scan algorithm for both sparse-safe and 
-unsafe
+               //in order to avoid binary search for sparse-unsafe
+               for(int i=rl; i<ru; i++) {
+                       int lastj = -1;
+                       //handle non-empty rows
+                       if( sblock != null && !sblock.isEmpty(i) ) {
+                               int apos = sblock.pos(i);
+                               int alen = sblock.size(i);
+                               int[] aix = sblock.indexes(i);
+                               double[] avals = sblock.values(i);
+                               for(int k=apos; k<apos+alen; k++) {
+                                       //process zeros before current non-zero
+                                       if( !sparseSafe )
+                                               for(int j=lastj+1; j<aix[k]; 
j++)
+                                                       genexec(0, lb, scalars, 
c, m, n, i, j);
+                                       //process current non-zero
+                                       lastj = aix[k];
+                                       genexec(avals[k], lb, scalars, c, m, n, 
i, lastj);
+                               }
                        }
+                       //process empty rows or remaining zeros
+                       if( !sparseSafe )
+                               for(int j=lastj+1; j<n; j++)
+                                       genexec(0, lb, scalars, c, m, n, i, j);
+               }
        }
 
        private void executeCompressed(CompressedMatrixBlock a, SideInput[] b, 
double[] scalars, double[] c, int m, int n, int rl, int ru) throws 
DMLRuntimeException 
@@ -251,16 +282,18 @@ public abstract class SpoofMultiAggregate extends 
SpoofOperator implements Seria
                private final double[] _scalars;
                private final int _rlen;
                private final int _clen;
+               private final boolean _safe;
                private final int _rl;
                private final int _ru;
 
                protected ParAggTask( MatrixBlock a, SideInput[] b, double[] 
scalars, 
-                               int rlen, int clen, int rl, int ru ) {
+                               int rlen, int clen, boolean safe, int rl, int 
ru ) {
                        _a = a;
                        _b = b;
                        _scalars = scalars;
                        _rlen = rlen;
                        _clen = clen;
+                       _safe = safe;
                        _rl = rl;
                        _ru = ru;
                }
@@ -274,7 +307,7 @@ public abstract class SpoofMultiAggregate extends 
SpoofOperator implements Seria
                        else if( !_a.isInSparseFormat() )
                                executeDense(_a.getDenseBlock(), _b, _scalars, 
c, _rlen, _clen, _rl, _ru);
                        else    
-                               executeSparse(_a.getSparseBlock(), _b, 
_scalars, c, _rlen, _clen, _rl, _ru);
+                               executeSparse(_a.getSparseBlock(), _b, 
_scalars, c, _rlen, _clen, _safe, _rl, _ru);
                        return c;
                }
        }

Reply via email to