[SYSTEMML-1543] Codegen multi-aggregates w/ matrix mult root nodes. This patch extends the compilation of codegen multi-aggregate templates with support for dot products, which is important because sum(X^2) and sum(X*Y) are rewritten to dot products by dynamic simplification rewrites. Furthermore, it includes minor fixes regarding indexing under sum-of-squares operations, the avoidance of fusion for 1x1 matrices, and an unrelated fix for COO sparse blocks, which came up after the hash function changes in SYSTEMML-1716.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c4342085 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c4342085 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c4342085 Branch: refs/heads/master Commit: c43420855d0d768d8826adac455bc03ff673d23e Parents: 23a164a Author: Matthias Boehm <[email protected]> Authored: Sat Jun 17 23:22:37 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sun Jun 18 11:51:02 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 5 +- .../template/PlanSelectionFuseCostBased.java | 54 ++++++++++---------- .../hops/codegen/template/TemplateCell.java | 12 ++++- .../sysml/hops/rewrite/HopRewriteUtils.java | 7 ++- .../runtime/matrix/data/SparseBlockCOO.java | 2 +- .../functions/codegen/MultiAggTmplTest.java | 18 ++++++- .../functions/codegen/multiAggPattern7.R | 34 ++++++++++++ .../functions/codegen/multiAggPattern7.dml | 31 +++++++++++ 8 files changed, 132 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index a58e28d..988af7c 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -56,6 +56,7 @@ import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntrySet; import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.hops.recompile.RecompileStatus; import org.apache.sysml.hops.recompile.Recompiler; +import org.apache.sysml.hops.AggUnaryOp; 
import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.HopsException; @@ -631,7 +632,9 @@ public class SpoofCompiler inHops[0].getRowsInBlock(), inHops[0].getColsInBlock(), -1); //inject artificial right indexing operations for all parents of all nodes for( int i=0; i<roots.size(); i++ ) { - Hop hnewi = HopRewriteUtils.createScalarIndexing(hnew, 1, i+1); + Hop hnewi = (roots.get(i) instanceof AggUnaryOp) ? + HopRewriteUtils.createScalarIndexing(hnew, 1, i+1) : + HopRewriteUtils.createMatrixIndexing(hnew, 1, i+1); HopRewriteUtils.rewireAllParentChildReferences(roots.get(i), hnewi); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java index 0e301e8..e3435e5 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java @@ -245,8 +245,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection ArrayList<Long> fullAggs = new ArrayList<Long>(); for( Long hopID : R ) { Hop root = memo._hopRefs.get(hopID); - if( !refHops.contains(hopID) && root instanceof AggUnaryOp - && ((AggUnaryOp)root).getDirection()==Direction.RowCol) + if( !refHops.contains(hopID) && isMultiAggregateRoot(root) ) fullAggs.add(hopID); } if( LOG.isTraceEnabled() ) { @@ -306,10 +305,19 @@ public class PlanSelectionFuseCostBased extends PlanSelection for( Long hopID : fullAggs ) { Hop aggHop = memo._hopRefs.get(hopID); AggregateInfo tmp = new AggregateInfo(aggHop); - for( Hop c : aggHop.getInput() ) + for( int i=0; i<aggHop.getInput().size(); i++ ) { + Hop c = 
HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ? + aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i); rExtractAggregateInfo(memo, c, tmp, TemplateType.CellTpl); - if( tmp._fusedInputs.isEmpty() ) - tmp.addFusedInput(aggHop.getInput().get(0).getHopID()); + } + if( tmp._fusedInputs.isEmpty() ) { + if( HopRewriteUtils.isMatrixMultiply(aggHop) ) { + tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID()); + tmp.addFusedInput(aggHop.getInput().get(1).getHopID()); + } + else + tmp.addFusedInput(aggHop.getInput().get(0).getHopID()); + } aggInfos.add(tmp); } @@ -319,10 +327,9 @@ public class PlanSelectionFuseCostBased extends PlanSelection LOG.trace(info); } - //filter aggregations w/ matmults to ensure consistent dims //sort aggregations by num dependencies to simplify merging //clusters of aggregations with parallel dependencies - aggInfos = aggInfos.stream().filter(a -> !a.containsMatMult) + aggInfos = aggInfos.stream() .sorted(Comparator.comparing(a -> a._inputAggs.size())) .collect(Collectors.toList()); @@ -366,6 +373,13 @@ public class PlanSelectionFuseCostBased extends PlanSelection } } + private static boolean isMultiAggregateRoot(Hop root) { + return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) + && ((AggUnaryOp)root).getDirection()==Direction.RowCol) + || (root instanceof AggBinaryOp && root.getDim1()==1 && root.getDim2()==1 + && HopRewriteUtils.isTransposeOperation(root.getInput().get(0))); + } + private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) { //ensure input consistent sizes (otherwise potential for incorrect results) boolean ret = true; @@ -402,11 +416,8 @@ public class PlanSelectionFuseCostBased extends PlanSelection return; //collect all applicable full aggregations per read - if( HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) - && ((AggUnaryOp)current).getDirection()==Direction.RowCol ) - { + if( 
isMultiAggregateRoot(current) ) aggs.add(current.getHopID()); - } //recursively process children for( Hop c : current.getInput() ) @@ -417,19 +428,12 @@ public class PlanSelectionFuseCostBased extends PlanSelection private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateType type) { //collect input aggregates (dependents) - if( HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) - && ((AggUnaryOp)current).getDirection()==Direction.RowCol ) - { + if( isMultiAggregateRoot(current) ) aggInfo.addInputAggregate(current.getHopID()); - } - - //collect included matrix multiplications - if( type != null && HopRewriteUtils.isMatrixMultiply(current) ) - aggInfo.setContainsMatMult(); //recursively process children MemoTableEntry me = (type!=null) ? memo.getBest(current.getHopID()) : null; - for( int i=0; i< current.getInput().size(); i++ ) { + for( int i=0; i<current.getInput().size(); i++ ) { Hop c = current.getInput().get(i); if( me != null && me.isPlanRef(i) ) rExtractAggregateInfo(memo, c, aggInfo, type); @@ -960,7 +964,6 @@ public class PlanSelectionFuseCostBased extends PlanSelection public final HashMap<Long,Hop> _aggregates; public final HashSet<Long> _inputAggs = new HashSet<Long>(); public final HashSet<Long> _fusedInputs = new HashSet<Long>(); - public boolean containsMatMult = false; public AggregateInfo(Hop aggregate) { _aggregates = new HashMap<Long, Hop>(); _aggregates.put(aggregate.getHopID(), aggregate); @@ -971,9 +974,6 @@ public class PlanSelectionFuseCostBased extends PlanSelection public void addFusedInput(long hopID) { _fusedInputs.add(hopID); } - public void setContainsMatMult() { - containsMatMult = true; - } public boolean isMergable(AggregateInfo that) { //check independence boolean ret = _aggregates.size()<3 @@ -986,9 +986,11 @@ public class PlanSelectionFuseCostBased extends PlanSelection ret &= !CollectionUtils.intersection( _fusedInputs, 
that._fusedInputs).isEmpty(); //check consistent sizes (result correctness) + Hop in1 = _aggregates.values().iterator().next(); + Hop in2 = that._aggregates.values().iterator().next(); return ret && HopRewriteUtils.isEqualSize( - _aggregates.values().iterator().next().getInput().get(0), - that._aggregates.values().iterator().next().getInput().get(0)); + in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1)?1:0), + in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2)?1:0)); } public AggregateInfo merge(AggregateInfo that) { _aggregates.putAll(that._aggregates); http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index 91c61c2..5455775 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -76,6 +76,7 @@ public class TemplateCell extends TemplateBase @Override public boolean open(Hop hop) { return hop.dimsKnown() && isValidOperation(hop) + && !(hop.getDim1()==1 && hop.getDim2()==1) || (hop instanceof IndexingOp && (((IndexingOp)hop) .isColLowerEqualsUpper() || hop.getDim2()==1)); } @@ -162,6 +163,8 @@ public class TemplateCell extends TemplateBase if( me!=null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type!=TemplateType.MultiAggTpl || memo.contains(c.getHopID(), TemplateType.CellTpl))) rConstructCplan(c, memo, tmp, inHops, compileLiterals); + else if( me!=null && me.type==TemplateType.MultiAggTpl && HopRewriteUtils.isMatrixMultiply(hop) && i==0 ) + rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals); else { CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals); tmp.put(c.getHopID(), cdata); @@ -233,7 
+236,10 @@ public class TemplateCell extends TemplateBase } else if( HopRewriteUtils.isTransposeOperation(hop) ) { - out = tmp.get(hop.getInput().get(0).getHopID()); + out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), + hop, tmp, compileLiterals); + if( out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)) ) + inHops.add(hop.getInput().get(0)); } else if( hop instanceof AggUnaryOp ) { @@ -246,11 +252,15 @@ public class TemplateCell extends TemplateBase //(1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y) if( HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1)) ) { CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID()); + if( TemplateUtils.isColVector(cdata1) ) + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); out = new CNodeUnary(cdata1, UnaryType.POW2); } else { CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), hop.getInput().get(0), tmp, compileLiterals); + if( cdata1 instanceof CNodeData && !inHops.contains(hop.getInput().get(0).getInput().get(0)) ) + inHops.add(hop.getInput().get(0).getInput().get(0)); if( TemplateUtils.isColVector(cdata1) ) cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index bec7b38..cf6081b 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -532,13 +532,18 @@ public class HopRewriteUtils } public static Hop createScalarIndexing(Hop input, long rix, long cix) { + Hop ix = createMatrixIndexing(input, rix, cix); + return createUnary(ix, 
OpOp1.CAST_AS_SCALAR); + } + + public static Hop createMatrixIndexing(Hop input, long rix, long cix) { LiteralOp row = new LiteralOp(rix); LiteralOp col = new LiteralOp(cix); IndexingOp ix = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, input, row, row, col, col, true, true); ix.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); copyLineNumbers(input, ix); ix.refreshSizeInformation(); - return createUnary(ix, OpOp1.CAST_AS_SCALAR); + return ix; } public static Hop createValueHop( Hop hop, boolean row ) http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java index 9ca9418..9f527d4 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java @@ -218,7 +218,7 @@ public class SparseBlockCOO extends SparseBlock @Override public int size(int r) { int pos = pos(r); - if( _rindexes[pos]!=r ) + if( pos>=_size || _rindexes[pos]!=r ) return 0; //count number of equal row indexes http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java index c33d680..07a4396 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java @@ -42,6 +42,7 @@ public class MultiAggTmplTest 
extends AutomatedTestBase private static final String TEST_NAME4 = TEST_NAME+"4"; //sum(X*Y), sum(X^2), sum(Y^2) private static final String TEST_NAME5 = TEST_NAME+"5"; //sum(V*X), sum(Y*Z), sum(X+Y-Z) private static final String TEST_NAME6 = TEST_NAME+"6"; //min(X), max(X), sum(X) + private static final String TEST_NAME7 = TEST_NAME+"7"; //t(X)%*%X, t(X)%*Y, t(Y)%*%Y private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + MultiAggTmplTest.class.getSimpleName() + "/"; @@ -53,7 +54,7 @@ public class MultiAggTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=6; i++) + for(int i=1; i<=7; i++) addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } @@ -147,6 +148,21 @@ public class MultiAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME6, false, ExecType.SPARK ); } + @Test + public void testCodegenMultiAggRewrite7CP() { + testCodegenIntegration( TEST_NAME7, true, ExecType.CP ); + } + + @Test + public void testCodegenMultiAgg7CP() { + testCodegenIntegration( TEST_NAME7, false, ExecType.CP ); + } + + @Test + public void testCodegenMultiAgg7Spark() { + testCodegenIntegration( TEST_NAME7, false, ExecType.SPARK ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/test/scripts/functions/codegen/multiAggPattern7.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/multiAggPattern7.R b/src/test/scripts/functions/codegen/multiAggPattern7.R new file mode 100644 index 0000000..b56f090 --- /dev/null +++ b/src/test/scripts/functions/codegen/multiAggPattern7.R @@ -0,0 +1,34 @@ 
+#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = seq(1,15); +Y = seq(2,16); + +r1 = t(X)%*%X; +r2 = t(X)%*%Y; +r3 = t(Y)%*%Y; +S = r1+r2+r3; + +writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/test/scripts/functions/codegen/multiAggPattern7.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/multiAggPattern7.dml b/src/test/scripts/functions/codegen/multiAggPattern7.dml new file mode 100644 index 0000000..3306fd8 --- /dev/null +++ b/src/test/scripts/functions/codegen/multiAggPattern7.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = seq(1,15); +Y = seq(2,16); +if(1==1){} + +r1 = t(X)%*%X; +r2 = t(X)%*%Y; +r3 = t(Y)%*%Y; +S = r1+r2+r3; + +write(S,$1)
