[SYSTEMML-1254] New sum-product rewrites (agg pushdown), stratstats

In the spirit of our SPOOF compiler framework and the existing sum(X%*%Y) rewrite, this patch adds the following two sum-product rewrites (where the first applies multiple times in stratstats):
* colSums(X %*% Y) -> colSums(X) %*% Y
* rowSums(X %*% Y) -> X %*% rowSums(Y)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b3ba9916 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b3ba9916 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b3ba9916 Branch: refs/heads/master Commit: b3ba991604cf79c5e3e2c0992fe2439ae47ce023 Parents: d3e617b Author: Matthias Boehm <[email protected]> Authored: Sun Feb 12 08:29:36 2017 +0100 Committer: Matthias Boehm <[email protected]> Committed: Sun Feb 12 09:42:03 2017 +0100 ---------------------------------------------------------------------- .../RewriteAlgebraicSimplificationDynamic.java | 54 +++++++++++--------- 1 file changed, 30 insertions(+), 24 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b3ba9916/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index e8e3862..6ffcbd5 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -2547,11 +2547,12 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule private Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos) { - //sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)) - //if not dot product, not applied since aggregate removed - //if sum not the only consumer, not applied to prevent redundancy + //sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product + //colSums(A%*%B) -> 
colSums(A)%*%B + //rowSums(A%*%B) -> A%*%rowSums(B) + //-- if not dot product, not applied since aggregate removed + //-- if sum not the only consumer, not applied to prevent redundancy if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum - && ((AggUnaryOp)hi).getDirection() == Direction.RowCol //full aggregate && hi.getInput().get(0) instanceof AggBinaryOp //A%*%B && (hi.getInput().get(0).getDim1()>1 || hi.getInput().get(0).getDim2()>1) //not dot product && hi.getInput().get(0).getParent().size()==1 ) //not multiple consumers of matrix mult @@ -2560,34 +2561,39 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule Hop left = hi2.getInput().get(0); Hop right = hi2.getInput().get(1); - //remove link from parent to diag + //remove link from parent to matrix mult HopRewriteUtils.removeChildReference(hi, hi2); //create new operators - AggUnaryOp colSum = new AggUnaryOp(left.getName(), left.getDataType(), left.getValueType(), AggOp.SUM, Direction.Col, left); - colSum.setRowsInBlock(left.getRowsInBlock()); - colSum.setColsInBlock(left.getColsInBlock()); - colSum.refreshSizeInformation(); - ReorgOp trans = HopRewriteUtils.createTranspose(colSum); - AggUnaryOp rowSum = new AggUnaryOp(right.getName(), right.getDataType(), right.getValueType(), AggOp.SUM, Direction.Row, right); - rowSum.setRowsInBlock(right.getRowsInBlock()); - rowSum.setColsInBlock(right.getColsInBlock()); - rowSum.refreshSizeInformation(); - BinaryOp mult = new BinaryOp(right.getName(), right.getDataType(), right.getValueType(), OpOp2.MULT, trans, rowSum); - mult.setRowsInBlock(right.getRowsInBlock()); - mult.setColsInBlock(right.getColsInBlock()); - mult.refreshSizeInformation(); - + Hop root = null; + //pattern: sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product + if( ((AggUnaryOp)hi).getDirection() == Direction.RowCol ) { + AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col); + ReorgOp trans = 
HopRewriteUtils.createTranspose(colSum); + AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row); + root = HopRewriteUtils.createBinary(trans, rowSum, OpOp2.MULT); + LOG.debug("Applied simplifySumMatrixMult RC."); + } + //colSums(A%*%B) -> colSums(A)%*%B + else if( ((AggUnaryOp)hi).getDirection() == Direction.Col ) { + AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col); + root = HopRewriteUtils.createMatrixMultiply(colSum, right); + LOG.debug("Applied simplifySumMatrixMult C."); + } + //rowSums(A%*%B) -> A%*%rowSums(B) + else if( ((AggUnaryOp)hi).getDirection() == Direction.Row ) { + AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row); + root = HopRewriteUtils.createMatrixMultiply(left, rowSum); + LOG.debug("Applied simplifySumMatrixMult R."); + } //rehang new subdag under current node (keep hi intact) - HopRewriteUtils.addChildReference(hi, mult, 0); + HopRewriteUtils.addChildReference(hi, root, 0); hi.refreshSizeInformation(); - + //cleanup if only consumer of intermediate if( hi2.getParent().isEmpty() ) - HopRewriteUtils.removeAllChildReferences( hi2 ); - - LOG.debug("Applied simplifySumMatrixMult."); + HopRewriteUtils.removeAllChildReferences( hi2 ); } return hi;
