Repository: systemml
Updated Branches:
  refs/heads/master b84a4933c -> 352c256a3


[SYSTEMML-1755] Fix simplification rewrite binary matrix-scalar ops

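This patch fixes the rewrite that simplifies matrix-scalar operations
into scalar-scalar operations so that it correctly checks whether the
binary operation is supported over scalars; operations that are undefined
over scalars (QUANTILE, CENTRALMOMENT, MINUS1_MULT, MINUS_NZ, LOG_NZ) are
now excluded from the rewrite.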

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/352c256a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/352c256a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/352c256a

Branch: refs/heads/master
Commit: 352c256a3d71bb587162120134f87e4a9a2df507
Parents: b84a493
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sun Jul 9 00:32:47 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sun Jul 9 00:32:47 2017 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/Hop.java    | 92 ++++++++++----------
 .../RewriteAlgebraicSimplificationStatic.java   |  8 +-
 2 files changed, 54 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/352c256a/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index 8f8afde..80d33f1 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -28,6 +28,8 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.lops.Binary;
+import org.apache.sysml.lops.BinaryScalar;
 import org.apache.sysml.lops.CSVReBlock;
 import org.apache.sysml.lops.Checkpoint;
 import org.apache.sysml.lops.Compression;
@@ -1143,53 +1145,53 @@ public abstract class Hop
 
        }
 
-       protected static final HashMap<Hop.OpOp2, 
org.apache.sysml.lops.Binary.OperationTypes> HopsOpOp2LopsB;
+       protected static final HashMap<Hop.OpOp2, Binary.OperationTypes> 
HopsOpOp2LopsB;
        static {
-               HopsOpOp2LopsB = new HashMap<Hop.OpOp2, 
org.apache.sysml.lops.Binary.OperationTypes>();
-               HopsOpOp2LopsB.put(OpOp2.PLUS, 
org.apache.sysml.lops.Binary.OperationTypes.ADD);
-               HopsOpOp2LopsB.put(OpOp2.MINUS, 
org.apache.sysml.lops.Binary.OperationTypes.SUBTRACT);
-               HopsOpOp2LopsB.put(OpOp2.MULT, 
org.apache.sysml.lops.Binary.OperationTypes.MULTIPLY);
-               HopsOpOp2LopsB.put(OpOp2.DIV, 
org.apache.sysml.lops.Binary.OperationTypes.DIVIDE);
-               HopsOpOp2LopsB.put(OpOp2.MODULUS, 
org.apache.sysml.lops.Binary.OperationTypes.MODULUS);
-               HopsOpOp2LopsB.put(OpOp2.INTDIV, 
org.apache.sysml.lops.Binary.OperationTypes.INTDIV);
-               HopsOpOp2LopsB.put(OpOp2.MINUS1_MULT, 
org.apache.sysml.lops.Binary.OperationTypes.MINUS1_MULTIPLY);
-               HopsOpOp2LopsB.put(OpOp2.LESS, 
org.apache.sysml.lops.Binary.OperationTypes.LESS_THAN);
-               HopsOpOp2LopsB.put(OpOp2.LESSEQUAL, 
org.apache.sysml.lops.Binary.OperationTypes.LESS_THAN_OR_EQUALS);
-               HopsOpOp2LopsB.put(OpOp2.GREATER, 
org.apache.sysml.lops.Binary.OperationTypes.GREATER_THAN);
-               HopsOpOp2LopsB.put(OpOp2.GREATEREQUAL, 
org.apache.sysml.lops.Binary.OperationTypes.GREATER_THAN_OR_EQUALS);
-               HopsOpOp2LopsB.put(OpOp2.EQUAL, 
org.apache.sysml.lops.Binary.OperationTypes.EQUALS);
-               HopsOpOp2LopsB.put(OpOp2.NOTEQUAL, 
org.apache.sysml.lops.Binary.OperationTypes.NOT_EQUALS);
-               HopsOpOp2LopsB.put(OpOp2.MIN, 
org.apache.sysml.lops.Binary.OperationTypes.MIN);
-               HopsOpOp2LopsB.put(OpOp2.MAX, 
org.apache.sysml.lops.Binary.OperationTypes.MAX);
-               HopsOpOp2LopsB.put(OpOp2.AND, 
org.apache.sysml.lops.Binary.OperationTypes.OR);
-               HopsOpOp2LopsB.put(OpOp2.OR, 
org.apache.sysml.lops.Binary.OperationTypes.AND);
-               HopsOpOp2LopsB.put(OpOp2.SOLVE, 
org.apache.sysml.lops.Binary.OperationTypes.SOLVE);
-               HopsOpOp2LopsB.put(OpOp2.POW, 
org.apache.sysml.lops.Binary.OperationTypes.POW);
-               HopsOpOp2LopsB.put(OpOp2.LOG, 
org.apache.sysml.lops.Binary.OperationTypes.NOTSUPPORTED);
-       }
-
-       protected static final HashMap<Hop.OpOp2, 
org.apache.sysml.lops.BinaryScalar.OperationTypes> HopsOpOp2LopsBS;
+               HopsOpOp2LopsB = new HashMap<Hop.OpOp2, 
Binary.OperationTypes>();
+               HopsOpOp2LopsB.put(OpOp2.PLUS, Binary.OperationTypes.ADD);
+               HopsOpOp2LopsB.put(OpOp2.MINUS, Binary.OperationTypes.SUBTRACT);
+               HopsOpOp2LopsB.put(OpOp2.MULT, Binary.OperationTypes.MULTIPLY);
+               HopsOpOp2LopsB.put(OpOp2.DIV, Binary.OperationTypes.DIVIDE);
+               HopsOpOp2LopsB.put(OpOp2.MODULUS, 
Binary.OperationTypes.MODULUS);
+               HopsOpOp2LopsB.put(OpOp2.INTDIV, Binary.OperationTypes.INTDIV);
+               HopsOpOp2LopsB.put(OpOp2.MINUS1_MULT, 
Binary.OperationTypes.MINUS1_MULTIPLY);
+               HopsOpOp2LopsB.put(OpOp2.LESS, Binary.OperationTypes.LESS_THAN);
+               HopsOpOp2LopsB.put(OpOp2.LESSEQUAL, 
Binary.OperationTypes.LESS_THAN_OR_EQUALS);
+               HopsOpOp2LopsB.put(OpOp2.GREATER, 
Binary.OperationTypes.GREATER_THAN);
+               HopsOpOp2LopsB.put(OpOp2.GREATEREQUAL, 
Binary.OperationTypes.GREATER_THAN_OR_EQUALS);
+               HopsOpOp2LopsB.put(OpOp2.EQUAL, Binary.OperationTypes.EQUALS);
+               HopsOpOp2LopsB.put(OpOp2.NOTEQUAL, 
Binary.OperationTypes.NOT_EQUALS);
+               HopsOpOp2LopsB.put(OpOp2.MIN, Binary.OperationTypes.MIN);
+               HopsOpOp2LopsB.put(OpOp2.MAX, Binary.OperationTypes.MAX);
+               HopsOpOp2LopsB.put(OpOp2.AND, Binary.OperationTypes.OR);
+               HopsOpOp2LopsB.put(OpOp2.OR, Binary.OperationTypes.AND);
+               HopsOpOp2LopsB.put(OpOp2.SOLVE, Binary.OperationTypes.SOLVE);
+               HopsOpOp2LopsB.put(OpOp2.POW, Binary.OperationTypes.POW);
+               HopsOpOp2LopsB.put(OpOp2.LOG, 
Binary.OperationTypes.NOTSUPPORTED);
+       }
+
+       protected static final HashMap<Hop.OpOp2, BinaryScalar.OperationTypes> 
HopsOpOp2LopsBS;
        static {
-               HopsOpOp2LopsBS = new HashMap<Hop.OpOp2, 
org.apache.sysml.lops.BinaryScalar.OperationTypes>();
-               HopsOpOp2LopsBS.put(OpOp2.PLUS, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.ADD); 
-               HopsOpOp2LopsBS.put(OpOp2.MINUS, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.SUBTRACT);
-               HopsOpOp2LopsBS.put(OpOp2.MULT, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.MULTIPLY);
-               HopsOpOp2LopsBS.put(OpOp2.DIV, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.DIVIDE);
-               HopsOpOp2LopsBS.put(OpOp2.MODULUS, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.MODULUS);
-               HopsOpOp2LopsBS.put(OpOp2.INTDIV, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.INTDIV);
-               HopsOpOp2LopsBS.put(OpOp2.LESS, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.LESS_THAN);
-               HopsOpOp2LopsBS.put(OpOp2.LESSEQUAL, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.LESS_THAN_OR_EQUALS);
-               HopsOpOp2LopsBS.put(OpOp2.GREATER, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.GREATER_THAN);
-               HopsOpOp2LopsBS.put(OpOp2.GREATEREQUAL, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.GREATER_THAN_OR_EQUALS);
-               HopsOpOp2LopsBS.put(OpOp2.EQUAL, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.EQUALS);
-               HopsOpOp2LopsBS.put(OpOp2.NOTEQUAL, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.NOT_EQUALS);
-               HopsOpOp2LopsBS.put(OpOp2.MIN, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.MIN);
-               HopsOpOp2LopsBS.put(OpOp2.MAX, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.MAX);
-               HopsOpOp2LopsBS.put(OpOp2.AND, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.AND);
-               HopsOpOp2LopsBS.put(OpOp2.OR, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.OR);
-               HopsOpOp2LopsBS.put(OpOp2.LOG, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.LOG);
-               HopsOpOp2LopsBS.put(OpOp2.POW, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.POW);
-               HopsOpOp2LopsBS.put(OpOp2.PRINT, 
org.apache.sysml.lops.BinaryScalar.OperationTypes.PRINT);
+               HopsOpOp2LopsBS = new HashMap<Hop.OpOp2, 
BinaryScalar.OperationTypes>();
+               HopsOpOp2LopsBS.put(OpOp2.PLUS, 
BinaryScalar.OperationTypes.ADD);       
+               HopsOpOp2LopsBS.put(OpOp2.MINUS, 
BinaryScalar.OperationTypes.SUBTRACT);
+               HopsOpOp2LopsBS.put(OpOp2.MULT, 
BinaryScalar.OperationTypes.MULTIPLY);
+               HopsOpOp2LopsBS.put(OpOp2.DIV, 
BinaryScalar.OperationTypes.DIVIDE);
+               HopsOpOp2LopsBS.put(OpOp2.MODULUS, 
BinaryScalar.OperationTypes.MODULUS);
+               HopsOpOp2LopsBS.put(OpOp2.INTDIV, 
BinaryScalar.OperationTypes.INTDIV);
+               HopsOpOp2LopsBS.put(OpOp2.LESS, 
BinaryScalar.OperationTypes.LESS_THAN);
+               HopsOpOp2LopsBS.put(OpOp2.LESSEQUAL, 
BinaryScalar.OperationTypes.LESS_THAN_OR_EQUALS);
+               HopsOpOp2LopsBS.put(OpOp2.GREATER, 
BinaryScalar.OperationTypes.GREATER_THAN);
+               HopsOpOp2LopsBS.put(OpOp2.GREATEREQUAL, 
BinaryScalar.OperationTypes.GREATER_THAN_OR_EQUALS);
+               HopsOpOp2LopsBS.put(OpOp2.EQUAL, 
BinaryScalar.OperationTypes.EQUALS);
+               HopsOpOp2LopsBS.put(OpOp2.NOTEQUAL, 
BinaryScalar.OperationTypes.NOT_EQUALS);
+               HopsOpOp2LopsBS.put(OpOp2.MIN, BinaryScalar.OperationTypes.MIN);
+               HopsOpOp2LopsBS.put(OpOp2.MAX, BinaryScalar.OperationTypes.MAX);
+               HopsOpOp2LopsBS.put(OpOp2.AND, BinaryScalar.OperationTypes.AND);
+               HopsOpOp2LopsBS.put(OpOp2.OR, BinaryScalar.OperationTypes.OR);
+               HopsOpOp2LopsBS.put(OpOp2.LOG, BinaryScalar.OperationTypes.LOG);
+               HopsOpOp2LopsBS.put(OpOp2.POW, BinaryScalar.OperationTypes.POW);
+               HopsOpOp2LopsBS.put(OpOp2.PRINT, 
BinaryScalar.OperationTypes.PRINT);
        }
 
        protected static final HashMap<Hop.OpOp2, 
org.apache.sysml.lops.Unary.OperationTypes> HopsOpOp2LopsU;

http://git-wip-us.apache.org/repos/asf/systemml/blob/352c256a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index b8f9369..53359cc 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -846,8 +846,14 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
        private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, 
int pos ) 
                throws HopsException
        {
+               // Note: This rewrite is not applicable for all binary 
operations because some of them 
+               // are undefined over scalars. We explicitly exclude potential 
conflicting matrix-scalar binary
+               // operations; other operations like cbind/rbind will never 
occur as matrix-scalar operations.
+               
                if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)  
-                  && hi.getInput().get(0) instanceof BinaryOp ) 
+                       && hi.getInput().get(0) instanceof BinaryOp
+                       && !HopRewriteUtils.isBinary(hi.getInput().get(0), 
OpOp2.QUANTILE, 
+                       OpOp2.CENTRALMOMENT, OpOp2.MINUS1_MULT, OpOp2.MINUS_NZ, 
OpOp2.LOG_NZ)) 
                {
                        BinaryOp bin = (BinaryOp) hi.getInput().get(0);
                        BinaryOp bout = null;

Reply via email to