[SYSTEMML-694] Improved wdivmm rewrites (outer-product-like mm only) Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/382df653 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/382df653 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/382df653
Branch: refs/heads/master Commit: 382df653a5438124b1cdff8c676a55b9c0d50976 Parents: a528f5e Author: Matthias Boehm <[email protected]> Authored: Fri Jul 29 23:25:01 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jul 30 16:23:21 2016 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/HopRewriteUtils.java | 11 +++++++++ .../RewriteAlgebraicSimplificationDynamic.java | 24 +++++++++++--------- 2 files changed, 24 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/382df653/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 385a888..f7e4656 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -785,6 +785,17 @@ public class HopRewriteUtils : (hop.getDim1()>0 && hop.getDim1()<=hop.getRowsInBlock()); } + /** + * Checks whether the given hop is an outer-product-like matrix multiplication, i.e., an AggBinaryOp whose left input has more rows than columns and whose right input has more columns than rows. + * @param hop the hop to check + * @return true if the hop is an outer-product-like matrix multiplication, false otherwise + */ + public static boolean isOuterProductLikeMM( Hop hop ) { + return hop instanceof AggBinaryOp + && hop.getInput().get(0).getDim1() > hop.getInput().get(0).getDim2() + && hop.getInput().get(1).getDim1() < hop.getInput().get(1).getDim2(); + } + public static boolean isEqualValue( LiteralOp hop1, LiteralOp hop2 ) throws HopsException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/382df653/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index dbde506..10953f5 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -1852,7 +1852,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule //alternative pattern: t(U) %*% (W*(U%*%t(V))) if( right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY) && HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv - && right.getInput().get(1) instanceof AggBinaryOp + && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1)) && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT { Hop W = right.getInput().get(0); @@ -1886,6 +1886,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule && right.getInput().get(1) instanceof BinaryOp && ((BinaryOp) right.getInput().get(1)).getOp() == Hop.OpOp2.PLUS && right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR + && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT { Hop W = right.getInput().get(0); @@ -1917,7 +1918,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule if( !appliedPattern && left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY) && HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv - && left.getInput().get(1) instanceof AggBinaryOp + && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1)) && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT { Hop W = 
left.getInput().get(0); @@ -1948,6 +1949,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule && left.getInput().get(1) instanceof BinaryOp && ((BinaryOp) left.getInput().get(1)).getOp() == Hop.OpOp2.PLUS && left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR + && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT { Hop W = left.getInput().get(0); @@ -1975,8 +1977,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule if( !appliedPattern && right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT && right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS - && right.getInput().get(1).getInput().get(0) instanceof AggBinaryOp - && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX + && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) + && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT { Hop W = right.getInput().get(0); @@ -2008,8 +2010,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule if( !appliedPattern && left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT && left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS - && left.getInput().get(1).getInput().get(0) instanceof AggBinaryOp - && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX + && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) + && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && 
HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT { Hop W = left.getInput().get(0); @@ -2038,8 +2040,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule if( !appliedPattern && right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT && right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS - && right.getInput().get(1).getInput().get(0) instanceof AggBinaryOp - && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX + && HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0)) + && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT { Hop W = right.getInput().get(0); @@ -2071,8 +2073,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule if( !appliedPattern && left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT && left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS - && left.getInput().get(1).getInput().get(0) instanceof AggBinaryOp - && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX + && HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0)) + && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT { Hop W = left.getInput().get(0); @@ -2105,7 +2107,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule && hi.getDim2() > 1 //not applied for vector-vector mult && hi.getInput().get(0).getDataType() == DataType.MATRIX && hi.getInput().get(0).getDim2() > 
hi.getInput().get(0).getColsInBlock() - && hi.getInput().get(1) instanceof AggBinaryOp + && HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1)) && (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT {
