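Add a new `wumm` rewrite pattern for element-wise multiplication by 2. The new pattern recognizes a `*2` or `2*` applied outside `W*(U%*%t(V))`, i.e. `(W*(U%*%t(V)))*2` or `2*(W*(U%*%t(V)))`, and rewrites it into a single weighted unary matrix multiply (`WUMM`) quaternary operator.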
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/479b9da4 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/479b9da4 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/479b9da4 Branch: refs/heads/master Commit: 479b9da4e6c605871a914ccb4b06ab6da5de21ed Parents: e93c487 Author: Dylan Hutchison <[email protected]> Authored: Thu Jul 13 01:14:48 2017 -0700 Committer: Dylan Hutchison <[email protected]> Committed: Thu Jul 13 01:14:48 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/ProgramRewriter.java | 2 +- .../RewriteAlgebraicSimplificationDynamic.java | 44 +++++++++++++++++++- 2 files changed, 43 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 59565df..7c4f861 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -54,7 +54,7 @@ public class ProgramRewriter private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName()); //internal local debug level - private static final boolean LDEBUG = false; + private static final boolean LDEBUG = false; private static final boolean CHECK = false; private ArrayList<HopRewriteRule> _dagRuleSet = null; http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 8cd71f4..6246270 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -29,11 +29,11 @@ import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.DataGenOp; import org.apache.sysml.hops.Hop; -import org.apache.sysml.hops.QuaternaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.DataGenMethod; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp1; +import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.Hop.OpOp3; import org.apache.sysml.hops.Hop.OpOp4; import org.apache.sysml.hops.Hop.ParamBuiltinOp; @@ -44,7 +44,7 @@ import org.apache.sysml.hops.LeftIndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.ParameterizedBuiltinOp; -import org.apache.sysml.hops.Hop.OpOp2; +import org.apache.sysml.hops.QuaternaryOp; import org.apache.sysml.hops.ReorgOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; @@ -1959,6 +1959,46 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule appliedPattern = true; LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")"); } + + //Pattern 1.5) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V)) + if( !appliedPattern + && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT) + && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2) + || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2))) + { + final Hop nl; // non-literal + if( hi.getInput().get(0) instanceof LiteralOp ) { + nl = hi.getInput().get(1); + } else { + nl = 
hi.getInput().get(0); + } + + if ( HopRewriteUtils.isBinary(nl, OpOp2.MULT) + && HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) //prevent mv + && nl.getDim2() > 1 //not applied for vector-vector mult + && nl.getInput().get(0).getDataType() == DataType.MATRIX + && nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock() + && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1)) + && (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain + && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) ) + { + final Hop W = nl.getInput().get(0); + final Hop U = nl.getInput().get(1).getInput().get(0); + Hop V = nl.getInput().get(1).getInput().get(1); + if( !HopRewriteUtils.isTransposeOperation(V) ) + V = HopRewriteUtils.createTranspose(V); + else + V = V.getInput().get(0); + + hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, + OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT); + hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock()); + hnew.refreshSizeInformation(); + + appliedPattern = true; + LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line "+hi.getBeginLine()+")"); + } + } //Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops if( !appliedPattern
