[SYSTEMML-1853] Fix rewrite matrix-scalar ops (incomplete blacklist)

This patch fixes the simplification rewrite for matrix-scalar binary
operations, which had an incomplete blacklist. This issue caused our
existing StepLinregDS algorithm to fail during dynamic recompilation
of the solve operation, which is not supported for matrix-scalar inputs
but was not yet in the blacklist. We now use a complete whitelist
instead to make this decision explicit.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1adfc726
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1adfc726
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1adfc726

Branch: refs/heads/master
Commit: 1adfc72662601bb3acd750106f08bc8ed88dfcd2
Parents: 8c87d2a
Author: Matthias Boehm <[email protected]>
Authored: Thu Aug 17 23:35:15 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Aug 18 14:15:46 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/BinaryOp.java    |  6 ++--
 .../RewriteAlgebraicSimplificationStatic.java   | 18 ++++++++----
 .../org/apache/sysml/lops/BinaryScalar.java     | 31 ++++++++------------
 3 files changed, 26 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/1adfc726/src/main/java/org/apache/sysml/hops/BinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java 
b/src/main/java/org/apache/sysml/hops/BinaryOp.java
index 4a450ad..54c06f7 100644
--- a/src/main/java/org/apache/sysml/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java
@@ -557,12 +557,10 @@ public class BinaryOp extends Hop
                DataType dt2 = getInput().get(1).getDataType();
                
                if (dt1 == dt2 && dt1 == DataType.SCALAR) {
-
                        // Both operands scalar
                        BinaryScalar binScalar1 = new 
BinaryScalar(getInput().get(0)
-                                       .constructLops(),
-                                       getInput().get(1).constructLops(), 
HopsOpOp2LopsBS
-                                                       .get(op), 
getDataType(), getValueType());
+                               
.constructLops(),getInput().get(1).constructLops(),
+                               HopsOpOp2LopsBS.get(op), getDataType(), 
getValueType());
                        binScalar1.getOutputParameters().setDimensions(0, 0, 0, 
0, -1);
                        setLineNumbers(binScalar1);
                        setLops(binScalar1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/1adfc726/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index c010bc2..eadf492 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -65,12 +65,19 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
        private static final Log LOG = 
LogFactory.getLog(RewriteAlgebraicSimplificationStatic.class.getName());
        
        //valid aggregation operation types for rowOp to colOp conversions and 
vice versa
-       private static AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new 
AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR};
+       private static final AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new 
AggOp[] {
+               AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, 
AggOp.VAR};
        
        //valid binary operations for distributive and associate reorderings
-       private static OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new 
OpOp2[]{OpOp2.PLUS, OpOp2.MINUS}; 
-       private static OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new 
OpOp2[]{OpOp2.PLUS, OpOp2.MULT}; 
-               
+       private static final OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new 
OpOp2[] {OpOp2.PLUS, OpOp2.MINUS}; 
+       private static final OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new 
OpOp2[] {OpOp2.PLUS, OpOp2.MULT};
+       
+       //valid binary operations for scalar operations
+       private static final OpOp2[] LOOKUP_VALID_SCALAR_BINARY = new OpOp2[] 
{OpOp2.AND, OpOp2.DIV, 
+               OpOp2.EQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.INTDIV, 
OpOp2.LESS, OpOp2.LESSEQUAL, 
+               OpOp2.LOG, OpOp2.MAX, OpOp2.MIN, OpOp2.MINUS, OpOp2.MODULUS, 
OpOp2.MULT, OpOp2.NOTEQUAL, 
+               OpOp2.OR, OpOp2.PLUS, OpOp2.POW};
+       
        @Override
        public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) 
                throws HopsException
@@ -852,8 +859,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                
                if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)  
                        && hi.getInput().get(0) instanceof BinaryOp
-                       && !HopRewriteUtils.isBinary(hi.getInput().get(0), 
OpOp2.QUANTILE, 
-                       OpOp2.CENTRALMOMENT, OpOp2.MINUS1_MULT, OpOp2.MINUS_NZ, 
OpOp2.LOG_NZ)) 
+                       && HopRewriteUtils.isBinary(hi.getInput().get(0), 
LOOKUP_VALID_SCALAR_BINARY)) 
                {
                        BinaryOp bin = (BinaryOp) hi.getInput().get(0);
                        BinaryOp bout = null;

http://git-wip-us.apache.org/repos/asf/systemml/blob/1adfc726/src/main/java/org/apache/sysml/lops/BinaryScalar.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/BinaryScalar.java 
b/src/main/java/org/apache/sysml/lops/BinaryScalar.java
index 7169101..a2c10a9 100644
--- a/src/main/java/org/apache/sysml/lops/BinaryScalar.java
+++ b/src/main/java/org/apache/sysml/lops/BinaryScalar.java
@@ -30,10 +30,8 @@ import org.apache.sysml.parser.Expression.*;
  * Lop to perform binary scalar operations. Both inputs must be scalars.
  * Example i = j + k, i = i + 1. 
  */
-
 public class BinaryScalar extends Lop 
-{      
-       
+{
        public enum OperationTypes {
                ADD, SUBTRACT, MULTIPLY, DIVIDE, MODULUS, INTDIV,
                LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, 
GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS,
@@ -42,7 +40,7 @@ public class BinaryScalar extends Lop
                IQSIZE,
        }
        
-       OperationTypes operation;
+       private final OperationTypes operation;
        
        /**
         * Constructor to perform a scalar operation
@@ -66,7 +64,7 @@ public class BinaryScalar extends Lop
                boolean aligner = false;
                boolean definesMRJob = false;
                lps.addCompatibility(JobType.INVALID);
-               this.lps.setProperties(inputs, ExecType.CP, 
ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+               lps.setProperties(inputs, ExecType.CP, 
ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
        }
 
        @Override
@@ -74,23 +72,19 @@ public class BinaryScalar extends Lop
                return "Operation: " + operation;
        }
        
-       public OperationTypes getOperationType(){
+       public OperationTypes getOperationType() {
                return operation;
        }
 
        @Override
        public String getInstructions(String input1, String input2, String 
output) throws LopsException
        {
-               String opString = getOpcode( operation );
-               
-               
-               
                StringBuilder sb = new StringBuilder();
                
                sb.append(getExecType());
                sb.append(Lop.OPERAND_DELIMITOR);
                
-               sb.append( opString );
+               sb.append( getOpcode(operation) );
                sb.append( OPERAND_DELIMITOR );
                
                sb.append( 
getInputs().get(0).prepScalarInputOperand(getExecType()) );
@@ -105,17 +99,15 @@ public class BinaryScalar extends Lop
        }
        
        @Override
-       public Lop.SimpleInstType getSimpleInstructionType()
-       {
-               switch (operation){
- 
-               default:
-                       return SimpleInstType.Scalar;
-               }
+       public Lop.SimpleInstType getSimpleInstructionType() {
+               return SimpleInstType.Scalar;
        }
        
        public static String getOpcode( OperationTypes op )
        {
+               if( op == null )
+                       throw new UnsupportedOperationException("Unable to get 
opcode for 'null'.");
+               
                switch ( op ) 
                {
                        /* Arithmetic */
@@ -169,7 +161,8 @@ public class BinaryScalar extends Lop
                                return "iqsize"; 
                                
                        default:
-                               throw new 
UnsupportedOperationException("Instruction is not defined for BinaryScalar 
operator: " + op);
+                               throw new 
UnsupportedOperationException("Instruction "
+                                       + "is not defined for BinaryScalar 
operator: " + op);
                }
        }
 }

Reply via email to