[SYSTEMML-1543] Codegen multi-aggregates w/ matrix mult root nodes

This patch extends the compilation of codegen multi-aggregate templates
with support for dot products, which is important because sum(X^2) and
sum(X*Y) are rewritten to dot products by dynamic simplification
rewrites. Furthermore, this also includes minor fixes regarding indexing
under sum square operations, avoidance of fusion for 1x1 matrices, and
an unrelated fix for COO sparse blocks which came up after hash function
changes in SYSTEMML-1716.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c4342085
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c4342085
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c4342085

Branch: refs/heads/master
Commit: c43420855d0d768d8826adac455bc03ff673d23e
Parents: 23a164a
Author: Matthias Boehm <[email protected]>
Authored: Sat Jun 17 23:22:37 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Jun 18 11:51:02 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       |  5 +-
 .../template/PlanSelectionFuseCostBased.java    | 54 ++++++++++----------
 .../hops/codegen/template/TemplateCell.java     | 12 ++++-
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  7 ++-
 .../runtime/matrix/data/SparseBlockCOO.java     |  2 +-
 .../functions/codegen/MultiAggTmplTest.java     | 18 ++++++-
 .../functions/codegen/multiAggPattern7.R        | 34 ++++++++++++
 .../functions/codegen/multiAggPattern7.dml      | 31 +++++++++++
 8 files changed, 132 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index a58e28d..988af7c 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -56,6 +56,7 @@ import 
org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntrySet;
 import org.apache.sysml.hops.codegen.template.TemplateUtils;
 import org.apache.sysml.hops.recompile.RecompileStatus;
 import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.HopsException;
@@ -631,7 +632,9 @@ public class SpoofCompiler
                                        inHops[0].getRowsInBlock(), 
inHops[0].getColsInBlock(), -1);
                                //inject artificial right indexing operations 
for all parents of all nodes
                                for( int i=0; i<roots.size(); i++ ) {
-                                       Hop hnewi = 
HopRewriteUtils.createScalarIndexing(hnew, 1, i+1);
+                                       Hop hnewi = (roots.get(i) instanceof 
AggUnaryOp) ? 
+                                               
HopRewriteUtils.createScalarIndexing(hnew, 1, i+1) :
+                                               
HopRewriteUtils.createMatrixIndexing(hnew, 1, i+1);
                                        
HopRewriteUtils.rewireAllParentChildReferences(roots.get(i), hnewi);
                                }
                        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
 
b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
index 0e301e8..e3435e5 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -245,8 +245,7 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                ArrayList<Long> fullAggs = new ArrayList<Long>();
                for( Long hopID : R ) {
                        Hop root = memo._hopRefs.get(hopID);
-                       if( !refHops.contains(hopID) && root instanceof 
AggUnaryOp 
-                               && 
((AggUnaryOp)root).getDirection()==Direction.RowCol)
+                       if( !refHops.contains(hopID) && 
isMultiAggregateRoot(root) )
                                fullAggs.add(hopID);
                }
                if( LOG.isTraceEnabled() ) {
@@ -306,10 +305,19 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                for( Long hopID : fullAggs ) {
                        Hop aggHop = memo._hopRefs.get(hopID);
                        AggregateInfo tmp = new AggregateInfo(aggHop);
-                       for( Hop c : aggHop.getInput() )
+                       for( int i=0; i<aggHop.getInput().size(); i++ ) {
+                               Hop c = 
HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ? 
+                                       
aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i);
                                rExtractAggregateInfo(memo, c, tmp, 
TemplateType.CellTpl);
-                       if( tmp._fusedInputs.isEmpty() )
-                               
tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
+                       }
+                       if( tmp._fusedInputs.isEmpty() ) {
+                               if( HopRewriteUtils.isMatrixMultiply(aggHop) ) {
+                                       
tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
+                                       
tmp.addFusedInput(aggHop.getInput().get(1).getHopID());
+                               }
+                               else    
+                                       
tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
+                       }
                        aggInfos.add(tmp);      
                }
                
@@ -319,10 +327,9 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                                LOG.trace(info);
                }
                
-               //filter aggregations w/ matmults to ensure consistent dims
                //sort aggregations by num dependencies to simplify merging
                //clusters of aggregations with parallel dependencies
-               aggInfos = aggInfos.stream().filter(a -> !a.containsMatMult)
+               aggInfos = aggInfos.stream()
                        .sorted(Comparator.comparing(a -> a._inputAggs.size()))
                        .collect(Collectors.toList());
                
@@ -366,6 +373,13 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                }
        }
        
+       private static boolean isMultiAggregateRoot(Hop root) {
+               return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, 
AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) 
+                               && 
((AggUnaryOp)root).getDirection()==Direction.RowCol)
+                       || (root instanceof AggBinaryOp && root.getDim1()==1 && 
root.getDim2()==1
+                               && 
HopRewriteUtils.isTransposeOperation(root.getInput().get(0)));
+       }
+       
        private static boolean isValidMultiAggregate(CPlanMemoTable memo, 
MemoTableEntry me) {
                //ensure input consistent sizes (otherwise potential for 
incorrect results)
                boolean ret = true;
@@ -402,11 +416,8 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                        return;
                
                //collect all applicable full aggregations per read
-               if( HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, 
AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX)
-                       && 
((AggUnaryOp)current).getDirection()==Direction.RowCol )
-               {
+               if( isMultiAggregateRoot(current) )
                        aggs.add(current.getHopID());
-               }
                
                //recursively process children
                for( Hop c : current.getInput() )
@@ -417,19 +428,12 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
        
        private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop 
current, AggregateInfo aggInfo, TemplateType type) {
                //collect input aggregates (dependents)
-               if( HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, 
AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX)
-                       && 
((AggUnaryOp)current).getDirection()==Direction.RowCol )
-               {
+               if( isMultiAggregateRoot(current) )
                        aggInfo.addInputAggregate(current.getHopID());
-               }
-               
-               //collect included matrix multiplications
-               if( type != null && HopRewriteUtils.isMatrixMultiply(current) )
-                       aggInfo.setContainsMatMult();
                
                //recursively process children
                MemoTableEntry me = (type!=null) ? 
memo.getBest(current.getHopID()) : null;
-               for( int i=0; i< current.getInput().size(); i++ ) {
+               for( int i=0; i<current.getInput().size(); i++ ) {
                        Hop c = current.getInput().get(i);
                        if( me != null && me.isPlanRef(i) )
                                rExtractAggregateInfo(memo, c, aggInfo, type);
@@ -960,7 +964,6 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                public final HashMap<Long,Hop> _aggregates;
                public final HashSet<Long> _inputAggs = new HashSet<Long>();
                public final HashSet<Long> _fusedInputs = new HashSet<Long>();
-               public boolean containsMatMult = false;
                public AggregateInfo(Hop aggregate) {
                        _aggregates = new HashMap<Long, Hop>();
                        _aggregates.put(aggregate.getHopID(), aggregate);
@@ -971,9 +974,6 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                public void addFusedInput(long hopID) {
                        _fusedInputs.add(hopID);
                }
-               public void setContainsMatMult() {
-                       containsMatMult = true;
-               }
                public boolean isMergable(AggregateInfo that) {
                        //check independence
                        boolean ret = _aggregates.size()<3 
@@ -986,9 +986,11 @@ public class PlanSelectionFuseCostBased extends 
PlanSelection
                        ret &= !CollectionUtils.intersection(
                                _fusedInputs, that._fusedInputs).isEmpty();
                        //check consistent sizes (result correctness)
+                       Hop in1 = _aggregates.values().iterator().next();
+                       Hop in2 = that._aggregates.values().iterator().next();
                        return ret && HopRewriteUtils.isEqualSize(
-                               
_aggregates.values().iterator().next().getInput().get(0),
-                               
that._aggregates.values().iterator().next().getInput().get(0));
+                               
in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1)?1:0),
+                               
in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2)?1:0));
                }
                public AggregateInfo merge(AggregateInfo that) {
                        _aggregates.putAll(that._aggregates);

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 91c61c2..5455775 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -76,6 +76,7 @@ public class TemplateCell extends TemplateBase
        @Override
        public boolean open(Hop hop) {
                return hop.dimsKnown() && isValidOperation(hop)
+                               && !(hop.getDim1()==1 && hop.getDim2()==1)      
                        || (hop instanceof IndexingOp && (((IndexingOp)hop)
                                .isColLowerEqualsUpper() || hop.getDim2()==1));
        }
@@ -162,6 +163,8 @@ public class TemplateCell extends TemplateBase
                        if( me!=null && me.isPlanRef(i) && !(c instanceof 
DataOp)
                                && (me.type!=TemplateType.MultiAggTpl || 
memo.contains(c.getHopID(), TemplateType.CellTpl)))
                                rConstructCplan(c, memo, tmp, inHops, 
compileLiterals);
+                       else if( me!=null && me.type==TemplateType.MultiAggTpl 
&& HopRewriteUtils.isMatrixMultiply(hop) && i==0 )
+                               rConstructCplan(c.getInput().get(0), memo, tmp, 
inHops, compileLiterals);
                        else {
                                CNodeData cdata = 
TemplateUtils.createCNodeData(c, compileLiterals);    
                                tmp.put(c.getHopID(), cdata);
@@ -233,7 +236,10 @@ public class TemplateCell extends TemplateBase
                }
                else if( HopRewriteUtils.isTransposeOperation(hop) ) 
                {
-                       out = tmp.get(hop.getInput().get(0).getHopID());        
+                       out = 
TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), 
+                               hop, tmp, compileLiterals);
+                       if( out instanceof CNodeData && 
!inHops.contains(hop.getInput().get(0)) )
+                               inHops.add(hop.getInput().get(0));
                }
                else if( hop instanceof AggUnaryOp )
                {
@@ -246,11 +252,15 @@ public class TemplateCell extends TemplateBase
                        //(1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
                        if( 
HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), 
hop.getInput().get(1)) ) {
                                CNode cdata1 = 
tmp.get(hop.getInput().get(1).getHopID());
+                               if( TemplateUtils.isColVector(cdata1) )
+                                       cdata1 = new CNodeUnary(cdata1, 
UnaryType.LOOKUP_R);
                                out = new CNodeUnary(cdata1, UnaryType.POW2);
                        }
                        else {
                                CNode cdata1 = 
TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), 
                                                hop.getInput().get(0), tmp, 
compileLiterals);
+                               if( cdata1 instanceof CNodeData && 
!inHops.contains(hop.getInput().get(0).getInput().get(0)) )
+                                       
inHops.add(hop.getInput().get(0).getInput().get(0));
                                if( TemplateUtils.isColVector(cdata1) )
                                        cdata1 = new CNodeUnary(cdata1, 
UnaryType.LOOKUP_R);
                                CNode cdata2 = 
tmp.get(hop.getInput().get(1).getHopID());

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index bec7b38..cf6081b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -532,13 +532,18 @@ public class HopRewriteUtils
        }
        
        public static Hop createScalarIndexing(Hop input, long rix, long cix) {
+               Hop ix = createMatrixIndexing(input, rix, cix);
+               return createUnary(ix, OpOp1.CAST_AS_SCALAR);
+       }
+       
+       public static Hop createMatrixIndexing(Hop input, long rix, long cix) {
                LiteralOp row = new LiteralOp(rix);
                LiteralOp col = new LiteralOp(cix);
                IndexingOp ix = new IndexingOp("tmp", DataType.MATRIX, 
ValueType.DOUBLE, input, row, row, col, col, true, true);
                ix.setOutputBlocksizes(input.getRowsInBlock(), 
input.getColsInBlock());
                copyLineNumbers(input, ix);
                ix.refreshSizeInformation();
-               return createUnary(ix, OpOp1.CAST_AS_SCALAR);
+               return ix;
        }
        
        public static Hop createValueHop( Hop hop, boolean row ) 

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
index 9ca9418..9f527d4 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
@@ -218,7 +218,7 @@ public class SparseBlockCOO extends SparseBlock
        @Override
        public int size(int r) {
                int pos = pos(r);
-               if( _rindexes[pos]!=r )
+               if( pos>=_size || _rindexes[pos]!=r )
                        return 0;
                
                //count number of equal row indexes

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java
index c33d680..07a4396 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java
@@ -42,6 +42,7 @@ public class MultiAggTmplTest extends AutomatedTestBase
        private static final String TEST_NAME4 = TEST_NAME+"4"; //sum(X*Y), 
sum(X^2), sum(Y^2)
        private static final String TEST_NAME5 = TEST_NAME+"5"; //sum(V*X), 
sum(Y*Z), sum(X+Y-Z)
        private static final String TEST_NAME6 = TEST_NAME+"6"; //min(X), 
max(X), sum(X)
+       private static final String TEST_NAME7 = TEST_NAME+"7"; //t(X)%*%X, 
t(X)%*%Y, t(Y)%*%Y
        
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
MultiAggTmplTest.class.getSimpleName() + "/";
@@ -53,7 +54,7 @@ public class MultiAggTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for(int i=1; i<=6; i++)
+               for(int i=1; i<=7; i++)
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) 
}) );
        }
        
@@ -147,6 +148,21 @@ public class MultiAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME6, false, ExecType.SPARK );
        }
        
+       @Test   
+       public void testCodegenMultiAggRewrite7CP() {
+               testCodegenIntegration( TEST_NAME7, true, ExecType.CP );
+       }
+
+       @Test   
+       public void testCodegenMultiAgg7CP() {
+               testCodegenIntegration( TEST_NAME7, false, ExecType.CP );
+       }
+       
+       @Test   
+       public void testCodegenMultiAgg7Spark() {
+               testCodegenIntegration( TEST_NAME7, false, ExecType.SPARK );
+       }
+       
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {       
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/test/scripts/functions/codegen/multiAggPattern7.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/multiAggPattern7.R 
b/src/test/scripts/functions/codegen/multiAggPattern7.R
new file mode 100644
index 0000000..b56f090
--- /dev/null
+++ b/src/test/scripts/functions/codegen/multiAggPattern7.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = seq(1,15);
+Y = seq(2,16);
+
+r1 = t(X)%*%X;
+r2 = t(X)%*%Y;
+r3 = t(Y)%*%Y;
+S = r1+r2+r3;
+
+writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/test/scripts/functions/codegen/multiAggPattern7.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/multiAggPattern7.dml 
b/src/test/scripts/functions/codegen/multiAggPattern7.dml
new file mode 100644
index 0000000..3306fd8
--- /dev/null
+++ b/src/test/scripts/functions/codegen/multiAggPattern7.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = seq(1,15);
+Y = seq(2,16);
+if(1==1){}
+
+r1 = t(X)%*%X;
+r2 = t(X)%*%Y;
+r3 = t(Y)%*%Y;
+S = r1+r2+r3;
+
+write(S,$1)

Reply via email to