Repository: systemml
Updated Branches:
  refs/heads/master 85e3a9631 -> eca9dbbb8 (forced update)


[SYSTEMML-1663] Fix and enable element-wise multiply chain rewrite

Groups the inputs of element-wise multiply chains by type.

Adds a comprehensive test that multiplies together all different kinds of objects.
The new order of element-wise multiply chains is as follows:

<pre>
    (((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix))
     * ([least-nnz-row-vector * row-vector] * most-nnz-row-vector))
    * ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector)
</pre>

Moves the rewrite into the dynamic rewrites. The rewrite is not applied if the dimensions of the top-level multiply are unknown.

Adds a new 'wumm' pattern so that expressions produced by the element-wise multiply rewrite are still recognized.
The new pattern matches a '*2' or '2*' applied outside 'W*(U%*%t(V))'.

Closes #567.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/eca9dbbb
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/eca9dbbb
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/eca9dbbb

Branch: refs/heads/master
Commit: eca9dbbb85971af688e81c9254538c53fc429b30
Parents: 1b3dff0
Author: Dylan Hutchison <dhutc...@cs.washington.edu>
Authored: Fri Jul 14 23:08:46 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Jul 14 23:08:46 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/ProgramRewriter.java     |   8 +-
 .../RewriteAlgebraicSimplificationDynamic.java  |  45 ++++-
 ...RewriteElementwiseMultChainOptimization.java | 180 +++++++++++++------
 ...ElementwiseMultChainOptimizationAllTest.java | 134 ++++++++++++++
 .../functions/misc/RewriteEMultChainOpAll.R     |  37 ++++
 .../functions/misc/RewriteEMultChainOpAll.dml   |  31 ++++
 6 files changed, 376 insertions(+), 59 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 92d31c2..7c4f861 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -54,7 +54,7 @@ public class ProgramRewriter
        private static final Log LOG = 
LogFactory.getLog(ProgramRewriter.class.getName());
        
        //internal local debug level
-       private static final boolean LDEBUG = false; 
+       private static final boolean LDEBUG = false;
        private static final boolean CHECK = false;
        
        private ArrayList<HopRewriteRule> _dagRuleSet = null;
@@ -96,8 +96,6 @@ public class ProgramRewriter
                        _dagRuleSet.add(     new 
RewriteRemoveUnnecessaryCasts()             );         
                        if( 
OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
                                _dagRuleSet.add( new 
RewriteCommonSubexpressionElimination()     );
-                       //if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
-                       //      _dagRuleSet.add( new 
RewriteElementwiseMultChainOptimization()   ); //dependency: cse
                        if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
                                _dagRuleSet.add( new RewriteConstantFolding()   
                 ); //dependency: cse
                        if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
@@ -125,7 +123,9 @@ public class ProgramRewriter
                // DYNAMIC REWRITES (which do require size information)
                if( dynamicRewrites )
                {
-                       _dagRuleSet.add(     new 
RewriteMatrixMultChainOptimization()         ); //dependency: cse 
+                       _dagRuleSet.add(     new 
RewriteMatrixMultChainOptimization()         ); //dependency: cse
+                       if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
+                               _dagRuleSet.add( new 
RewriteElementwiseMultChainOptimization()    ); //dependency: cse
                        
                        if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
                        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 8cd71f4..09b66de 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -29,11 +29,11 @@ import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.DataGenOp;
 import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.QuaternaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.DataGenMethod;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.OpOp4;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
@@ -44,7 +44,7 @@ import org.apache.sysml.hops.LeftIndexingOp;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
-import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.QuaternaryOp;
 import org.apache.sysml.hops.ReorgOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.UnaryOp;
@@ -1959,6 +1959,47 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        appliedPattern = true;
                        LOG.debug("Applied simplifyWeightedUnaryMM1 (line 
"+hi.getBeginLine()+")");     
                }
+
+               //Pattern 2.7) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
+               if( !appliedPattern
+                               && hi instanceof BinaryOp && 
HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT)
+                               && 
(HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
+                                       || 
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2)))
+               {
+                       final Hop nl; // non-literal
+                       if( hi.getInput().get(0) instanceof LiteralOp ) {
+                               nl = hi.getInput().get(1);
+                       } else {
+                               nl = hi.getInput().get(0);
+                       }
+
+                       if (       HopRewriteUtils.isBinary(nl, OpOp2.MULT)
+                                       && nl.getParent().size()==1 // ensure 
no foreign parents
+                                       && 
HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) 
//prevent mv
+                                       && nl.getDim2() > 1 //not applied for 
vector-vector mult
+                                       && nl.getInput().get(0).getDataType() 
== DataType.MATRIX
+                                       && nl.getInput().get(0).getDim2() > 
nl.getInput().get(0).getColsInBlock()
+                                       && 
HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
+                                       && (((AggBinaryOp) 
nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || 
nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
+                                       && 
HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) )
+                       {
+                               final Hop W = nl.getInput().get(0);
+                               final Hop U = 
nl.getInput().get(1).getInput().get(0);
+                               Hop V = nl.getInput().get(1).getInput().get(1);
+                               if( !HopRewriteUtils.isTransposeOperation(V) )
+                                       V = HopRewriteUtils.createTranspose(V);
+                               else
+                                       V = V.getInput().get(0);
+
+                               hnew = new QuaternaryOp(hi.getName(), 
DataType.MATRIX, ValueType.DOUBLE,
+                                               OpOp4.WUMM, W, U, V, true, 
null, OpOp2.MULT);
+                               hnew.setOutputBlocksizes(W.getRowsInBlock(), 
W.getColsInBlock());
+                               hnew.refreshSizeInformation();
+
+                               appliedPattern = true;
+                               LOG.debug("Applied simplifyWeightedUnaryMM2.7 
(line "+hi.getBeginLine()+")");
+                       }
+               }
                
                //Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to 
unary ops
                if( !appliedPattern

http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index c2c3b11..2e411f6 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -26,8 +26,8 @@ import java.util.HashSet;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
+import java.util.SortedSet;
+import java.util.TreeSet;
 
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
@@ -44,6 +44,15 @@ import org.apache.sysml.parser.Expression;
  *
  * Does not rewrite in the presence of foreign parents in the middle of the 
e-wise multiply chain,
  * since foreign parents may rely on the individual results.
+ * Does not perform rewrites on an element-wise multiply if its dimensions are 
unknown.
+ *
+ * The new order of element-wise multiply chains is as follows:
+ * <pre>
+ *     (((unknown * object * frame) * ([least-nnz-matrix * matrix] * 
most-nnz-matrix))
+ *      * ([least-nnz-row-vector * row-vector] * most-nnz-row-vector))
+ *     * ([[scalars * least-nnz-col-vector] * col-vector] * 
most-nnz-col-vector)
+ * </pre>
+ * Identical elements are replaced with powers.
  */
 public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
        @Override
@@ -73,15 +82,18 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
                        return root;
                root.setVisited();
 
-               // 1. Find immediate subtree of EMults.
-               if (isBinaryMult(root)) {
+               // 1. Find immediate subtree of EMults. Check dimsKnown.
+               if (isBinaryMult(root) && root.dimsKnown()) {
                        final Hop left = root.getInput().get(0), right = 
root.getInput().get(1);
+                       // The set of BinaryOp element-wise multiply hops in 
the emult chain.
                        final Set<BinaryOp> emults = new HashSet<>();
+                       // The multiset of multiplicands in the emult chain.
                        final Map<Hop, Integer> leaves = new HashMap<>(); // 
poor man's HashMultiset
                        findEMultsAndLeaves((BinaryOp)root, emults, leaves);
 
                        // 2. Ensure it is profitable to do a rewrite.
-                       if (isOptimizable(emults, leaves)) {
+                       // Only optimize a subtree of emults if there are at 
least two emults (meaning, at least 3 multiplicands).
+                       if (emults.size() >= 2) {
                                // 3. Check for foreign parents.
                                // A foreign parent is a parent of some EMult 
that is not in the set.
                                // Foreign parents destroy correctness of this 
rewrite.
@@ -123,38 +135,110 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
 
        private static Hop constructReplacement(final Set<BinaryOp> emults, 
final Map<Hop, Integer> leaves) {
                // Sort by data type
-               final SortedMap<Hop,Integer> sorted = new 
TreeMap<>(compareByDataType);
+               final SortedSet<Hop> sorted = new TreeSet<>(compareByDataType);
                for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
                        final Hop h = entry.getKey();
                        // unlink parents that are in the emult set(we are 
throwing them away)
                        // keep other parents
                        h.getParent().removeIf(parent -> parent instanceof 
BinaryOp && emults.contains(parent));
-                       sorted.put(h, entry.getValue());
+                       sorted.add(constructPower(h, entry.getValue()));
                }
                // sorted contains all leaves, sorted by data type, stripped 
from their parents
 
                // Construct right-deep EMult tree
-               // TODO compile binary outer mult for transition from row and 
column vectors to matrices
-               // TODO compile subtree for column vectors to avoid blow-up of 
intermediates on row-col vector transition
-               final Iterator<Map.Entry<Hop, Integer>> iterator = 
sorted.entrySet().iterator();
-               Hop first = constructPower(iterator.next());
-
-               for (int i = 1; i < sorted.size(); i++) {
-                       final Hop second = constructPower(iterator.next());
-                       first = HopRewriteUtils.createBinary(second, first, 
Hop.OpOp2.MULT);
-                       first.setVisited();
+               final Iterator<Hop> iterator = sorted.iterator();
+
+               Hop next = iterator.hasNext() ? iterator.next() : null;
+               Hop colVectorsScalars = null;
+               while(next != null &&
+                               (next.getDataType() == 
Expression.DataType.SCALAR
+                                               || next.getDataType() == 
Expression.DataType.MATRIX && next.getDim2() == 1))
+               {
+                       if( colVectorsScalars == null )
+                               colVectorsScalars = next;
+                       else {
+                               colVectorsScalars = 
HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
+                               colVectorsScalars.setVisited();
+                       }
+                       next = iterator.hasNext() ? iterator.next() : null;
+               }
+               // next is not processed and is either null or past col vectors
+
+               Hop rowVectors = null;
+               while(next != null &&
+                               (next.getDataType() == 
Expression.DataType.MATRIX && next.getDim1() == 1))
+               {
+                       if( rowVectors == null )
+                               rowVectors = next;
+                       else {
+                               rowVectors = 
HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
+                               rowVectors.setVisited();
+                       }
+                       next = iterator.hasNext() ? iterator.next() : null;
+               }
+               // next is not processed and is either null or past row vectors
+
+               Hop matrices = null;
+               while(next != null &&
+                               (next.getDataType() == 
Expression.DataType.MATRIX))
+               {
+                       if( matrices == null )
+                               matrices = next;
+                       else {
+                               matrices = 
HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
+                               matrices.setVisited();
+                       }
+                       next = iterator.hasNext() ? iterator.next() : null;
+               }
+               // next is not processed and is either null or past matrices
+
+               Hop other = null;
+               while(next != null)
+               {
+                       if( other == null )
+                               other = next;
+                       else {
+                               other = HopRewriteUtils.createBinary(other, 
next, Hop.OpOp2.MULT);
+                               other.setVisited();
+                       }
+                       next = iterator.hasNext() ? iterator.next() : null;
+               }
+               // finished
+
+               // ((other * matrices) * rowVectors) * colVectorsScalars
+               Hop top = null;
+               if( other == null && matrices != null )
+                       top = matrices;
+               else if( other != null && matrices == null )
+                       top = other;
+               else if( other != null ) { //matrices != null
+                       top = HopRewriteUtils.createBinary(other, matrices, 
Hop.OpOp2.MULT);
+                       top.setVisited();
+               }
+
+               if( top == null && rowVectors != null )
+                       top = rowVectors;
+               else if( rowVectors != null ) { //top != null
+                       top = HopRewriteUtils.createBinary(top, rowVectors, 
Hop.OpOp2.MULT);
+                       top.setVisited();
+               }
+
+               if( top == null && colVectorsScalars != null )
+                       top = colVectorsScalars;
+               else if( colVectorsScalars != null ) { //top != null
+                       top = HopRewriteUtils.createBinary(top, 
colVectorsScalars, Hop.OpOp2.MULT);
+                       top.setVisited();
                }
-               return first;
+
+               return top;
        }
 
-       private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
-               final Hop hop = entry.getKey();
-               final int cnt = entry.getValue();
+       private static Hop constructPower(final Hop hop, final int cnt) {
                assert(cnt >= 1);
                hop.setVisited(); // we will visit the leaves' children next
                if (cnt == 1)
                        return hop;
-               Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), 
Hop.OpOp2.POW);
+               final Hop pow = HopRewriteUtils.createBinary(hop, new 
LiteralOp(cnt), Hop.OpOp2.POW);
                pow.setVisited();
                return pow;
        }
@@ -162,8 +246,8 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
        /**
         * A Comparator that orders Hops by their data type, dimension, and 
sparsity.
         * The order is as follows:
-        *              scalars > row vectors > col vectors >
-        *      non-vector matrices ordered by sparsity (higher nnz first, 
unknown sparsity last) >
+        *              scalars < col vectors < row vectors <
+        *      non-vector matrices ordered by sparsity (higher nnz first, 
unknown sparsity last) <
         *      other data types.
         * Disambiguate by Hop ID.
         */
@@ -174,33 +258,33 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
                {
                        for (int i = 0, valuesLength = 
Expression.DataType.values().length; i < valuesLength; i++)
                                switch(Expression.DataType.values()[i]) {
-                               case SCALAR: orderDataType[i] = 4; break;
-                               case MATRIX: orderDataType[i] = 3; break;
+                               case SCALAR: orderDataType[i] = 0; break;
+                               case MATRIX: orderDataType[i] = 1; break;
                                case FRAME:  orderDataType[i] = 2; break;
-                               case OBJECT: orderDataType[i] = 1; break;
-                               case UNKNOWN:orderDataType[i] = 0; break;
+                               case OBJECT: orderDataType[i] = 3; break;
+                               case UNKNOWN:orderDataType[i] = 4; break;
                                }
                }
 
                @Override
-               public final int compare(Hop o1, Hop o2) {
-                       int c = 
Integer.compare(orderDataType[o1.getDataType().ordinal()], 
orderDataType[o2.getDataType().ordinal()]);
+               public final int compare(final Hop o1, final Hop o2) {
+                       final int c = 
Integer.compare(orderDataType[o1.getDataType().ordinal()], 
orderDataType[o2.getDataType().ordinal()]);
                        if (c != 0) return c;
 
                        // o1 and o2 have the same data type
                        switch (o1.getDataType()) {
                        case MATRIX:
                                // two matrices; check for vectors
-                               if (o1.getDim1() == 1) { // row vector
-                                       if (o2.getDim1() != 1) return 1; // row 
vectors are greatest of matrices
-                                       return compareBySparsityThenId(o1, o2); 
// both row vectors
-                               } else if (o2.getDim1() == 1) { // 2 is row 
vector; 1 is not
-                                       return -1; // row vectors are the 
greatest matrices
-                               } else if (o1.getDim2() == 1) { // col vector
-                                       if (o2.getDim2() != 1) return 1; // col 
vectors greater than non-vectors
-                                       return compareBySparsityThenId(o1, o2); 
// both col vectors
+                               if (o1.getDim2() == 1) { // col vector
+                                               if (o2.getDim2() != 1) return 
-1; // col vectors are greatest of matrices
+                                               return 
compareBySparsityThenId(o1, o2); // both col vectors
                                } else if (o2.getDim2() == 1) { // 2 is col 
vector; 1 is not
-                                       return -1; // col vectors greater than 
non-vectors
+                                               return 1; // col vectors are 
the greatest matrices
+                               } else if (o1.getDim1() == 1) { // row vector
+                                               if (o2.getDim1() != 1) return 
-1; // row vectors greater than non-vectors
+                                               return 
compareBySparsityThenId(o1, o2); // both row vectors
+                               } else if (o2.getDim1() == 1) { // 2 is row 
vector; 1 is not
+                                               return 1; // row vectors 
greater than non-vectors
                                } else { // both non-vectors
                                        return compareBySparsityThenId(o1, o2);
                                }
@@ -208,13 +292,13 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
                                return Long.compare(o1.getHopID(), 
o2.getHopID());
                        }
                }
-               private int compareBySparsityThenId(Hop o1, Hop o2) {
+               private int compareBySparsityThenId(final Hop o1, final Hop o2) 
{
                        // the hop with more nnz is first; unknown nnz (-1) last
-                       int c = Long.compare(o1.getNnz(), o2.getNnz());
-                       if (c != 0) return c;
+                       final int c = Long.compare(o1.getNnz(), o2.getNnz());
+                       if (c != 0) return -c;
                        return Long.compare(o1.getHopID(), o2.getHopID());
                }
-       }.reversed();
+       };
 
        /**
         * Check if a node has a parent that is not in the set of emults. 
Recursively check children who are also emults.
@@ -242,8 +326,7 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
         * @param emults Out parameter. The set of BinaryOp element-wise 
multiply hops in the emult chain (including root).
         * @param leaves Out parameter. The multiset of multiplicands in the 
emult chain.
         */
-       private static void findEMultsAndLeaves(final BinaryOp root, final 
Set<BinaryOp> emults,
-                       final Map<Hop, Integer> leaves) {
+       private static void findEMultsAndLeaves(final BinaryOp root, final 
Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
                // Because RewriteCommonSubexpressionElimination already ran, 
it is safe to compare by equality.
                emults.add(root);
                
@@ -268,13 +351,4 @@ public class RewriteElementwiseMultChainOptimization 
extends HopRewriteRule {
                map.put(k, map.getOrDefault(k, 0) + 1);
        }
 
-       /**
-        * Only optimize a subtree of emults if there are at least two emults 
(meaning, at least 3 multiplicands).
-        * @param emults The set of BinaryOp element-wise multiply hops in the 
emult chain.
-        * @param leaves The multiset of multiplicands in the emult chain.
-        * @return If the multiset is worth optimizing.
-        */
-       private static boolean isOptimizable(final Set<BinaryOp> emults, final 
Map<Hop, Integer> leaves) {
-               return emults.size() >= 2;
-       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
new file mode 100644
index 0000000..ba5c78d
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test rewriting `2*X*3*v*5*w*4*z*5*Y*2*v*2*X`, where `v` and `z` are row 
vectors and `w` is a column vector,
+ * successfully rewrites to `Y*(X^2)*z*(v^2)*w*2400`.
+ */
+public class RewriteElementwiseMultChainOptimizationAllTest extends 
AutomatedTestBase
+{
+       private static final String TEST_NAME1 = "RewriteEMultChainOpAll";
+       private static final String TEST_DIR = "functions/misc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteElementwiseMultChainOptimizationAllTest.class.getSimpleName() + "/";
+       
+       private static final int rows = 123;
+       private static final int cols = 321;
+       private static final double eps = Math.pow(10, -10);
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+       }
+
+       @Test
+       public void testMatrixMultChainOptNoRewritesCP() {
+               testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+       }
+       
+       @Test
+       public void testMatrixMultChainOptNoRewritesSP() {
+               testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testMatrixMultChainOptRewritesCP() {
+               testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+       }
+       
+       @Test
+       public void testMatrixMultChainOptRewritesSP() {
+               testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+       }
+
+       private void testRewriteMatrixMultChainOp(String testname, boolean 
rewrites, ExecType et)
+       {       
+               RUNTIME_PLATFORM platformOld = rtplatform;
+               switch( et ){
+                       case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+                       case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; 
break;
+               }
+               
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if( rtplatform == RUNTIME_PLATFORM.SPARK )
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               
+               boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+               OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+               
+               try
+               {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[] { "-explain", "hops", 
"-stats", "-args", input("X"), input("Y"), input("v"), input("z"), input("w"), 
output("R") };
+                       fullRScriptName = HOME + testname + ".R";
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       double Xsparsity = 0.8, Ysparsity = 0.6;
+                       double[][] X = getRandomMatrix(rows, cols, -1, 1, 
Xsparsity, 7);
+                       double[][] Y = getRandomMatrix(rows, cols, -1, 1, 
Ysparsity, 3);
+                       double[][] z = getRandomMatrix(1, cols, -1, 1, 
Ysparsity, 5);
+                       double[][] v = getRandomMatrix(1, cols, -1, 1, 
Xsparsity, 4);
+                       double[][] w = getRandomMatrix(rows, 1, -1, 1, 
Ysparsity, 6);
+                       writeInputMatrixWithMTD("X", X, true);
+                       writeInputMatrixWithMTD("Y", Y, true);
+                       writeInputMatrixWithMTD("z", z, true);
+                       writeInputMatrixWithMTD("v", v, true);
+                       writeInputMatrixWithMTD("w", w, true);
+
+                       //execute tests
+                       runTest(true, false, null, -1); 
+                       runRScript(true); 
+                       
+                       //compare matrices 
+                       HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("R");
+                       HashMap<CellIndex, Double> rfile  = 
readRMatrixFromFS("R");
+                       TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
+                       
+                       //check for presence of power operator, if we did a 
rewrite
+                       if( rewrites ) {
+                               
Assert.assertTrue(heavyHittersContainsSubString("^2"));
+                       }
+               }
+               finally {
+                       OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R 
b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
new file mode 100644
index 0000000..20f76c2
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+# args[1]: input path prefix of the X/Y/v/z/w .mtx files; args[2]: output path prefix for R
+options(digits=22)
+library("Matrix")
+library("matrixStats")  # NOTE(review): no matrixStats function is used in this script
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+v = as.matrix(readMM(paste(args[1], "v.mtx", sep="")))
+z = as.matrix(readMM(paste(args[1], "z.mtx", sep="")))
+w = as.matrix(readMM(paste(args[1], "w.mtx", sep="")))
+# R '*' needs conformable matrices, so expand row vectors z, v and column vector w to length(w) x length(v)
+R = 2* X *3* X *5* Y *4*5*2*2* (matrix(1,length(w),1)%*%z) * (matrix(1,length(w),1)%*%v)^2 * (w%*%matrix(1,1,length(v)))
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));

http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
new file mode 100644
index 0000000..90f9242
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+v = read($3);
+z = read($4);
+w = read($5);
+# e-wise multiply chain of scalars, matrices X and Y, row vectors v and z, column vector w; X and v each appear twice
+R = 2* X *3* v *5* w *4* z *5* Y *2* v *2* X;
+
+write(R, $6);

Reply via email to