[SYSTEMML-1853] Fix rewrite matrix-scalar ops (incomplete blacklist) This patch fixes the simplification rewrite for matrix-scalar binary operations, which had an incomplete blacklist. This issue caused our existing StepLinregDS algorithm to fail during dynamic recompilation of the solve operation, which is not supported over matrix-scalar operands but was not yet in the blacklist. We now use a complete whitelist instead to make this decision explicit.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1adfc726 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1adfc726 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1adfc726 Branch: refs/heads/master Commit: 1adfc72662601bb3acd750106f08bc8ed88dfcd2 Parents: 8c87d2a Author: Matthias Boehm <[email protected]> Authored: Thu Aug 17 23:35:15 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Aug 18 14:15:46 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/BinaryOp.java | 6 ++-- .../RewriteAlgebraicSimplificationStatic.java | 18 ++++++++---- .../org/apache/sysml/lops/BinaryScalar.java | 31 ++++++++------------ 3 files changed, 26 insertions(+), 29 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/1adfc726/src/main/java/org/apache/sysml/hops/BinaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java index 4a450ad..54c06f7 100644 --- a/src/main/java/org/apache/sysml/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java @@ -557,12 +557,10 @@ public class BinaryOp extends Hop DataType dt2 = getInput().get(1).getDataType(); if (dt1 == dt2 && dt1 == DataType.SCALAR) { - // Both operands scalar BinaryScalar binScalar1 = new BinaryScalar(getInput().get(0) - .constructLops(), - getInput().get(1).constructLops(), HopsOpOp2LopsBS - .get(op), getDataType(), getValueType()); + .constructLops(),getInput().get(1).constructLops(), + HopsOpOp2LopsBS.get(op), getDataType(), getValueType()); binScalar1.getOutputParameters().setDimensions(0, 0, 0, 0, -1); setLineNumbers(binScalar1); setLops(binScalar1); 
http://git-wip-us.apache.org/repos/asf/systemml/blob/1adfc726/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index c010bc2..eadf492 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -65,12 +65,19 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule private static final Log LOG = LogFactory.getLog(RewriteAlgebraicSimplificationStatic.class.getName()); //valid aggregation operation types for rowOp to colOp conversions and vice versa - private static AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR}; + private static final AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[] { + AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR}; //valid binary operations for distributive and associate reorderings - private static OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MINUS}; - private static OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT}; - + private static final OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new OpOp2[] {OpOp2.PLUS, OpOp2.MINUS}; + private static final OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new OpOp2[] {OpOp2.PLUS, OpOp2.MULT}; + + //valid binary operations for scalar operations + private static final OpOp2[] LOOKUP_VALID_SCALAR_BINARY = new OpOp2[] {OpOp2.AND, OpOp2.DIV, + OpOp2.EQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL, OpOp2.INTDIV, OpOp2.LESS, OpOp2.LESSEQUAL, + OpOp2.LOG, OpOp2.MAX, OpOp2.MIN, OpOp2.MINUS, OpOp2.MODULUS, OpOp2.MULT, OpOp2.NOTEQUAL, + 
OpOp2.OR, OpOp2.PLUS, OpOp2.POW}; + @Override public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException @@ -852,8 +859,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR) && hi.getInput().get(0) instanceof BinaryOp - && !HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.QUANTILE, - OpOp2.CENTRALMOMENT, OpOp2.MINUS1_MULT, OpOp2.MINUS_NZ, OpOp2.LOG_NZ)) + && HopRewriteUtils.isBinary(hi.getInput().get(0), LOOKUP_VALID_SCALAR_BINARY)) { BinaryOp bin = (BinaryOp) hi.getInput().get(0); BinaryOp bout = null; http://git-wip-us.apache.org/repos/asf/systemml/blob/1adfc726/src/main/java/org/apache/sysml/lops/BinaryScalar.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/BinaryScalar.java b/src/main/java/org/apache/sysml/lops/BinaryScalar.java index 7169101..a2c10a9 100644 --- a/src/main/java/org/apache/sysml/lops/BinaryScalar.java +++ b/src/main/java/org/apache/sysml/lops/BinaryScalar.java @@ -30,10 +30,8 @@ import org.apache.sysml.parser.Expression.*; * Lop to perform binary scalar operations. Both inputs must be scalars. * Example i = j + k, i = i + 1. 
*/ - public class BinaryScalar extends Lop -{ - +{ public enum OperationTypes { ADD, SUBTRACT, MULTIPLY, DIVIDE, MODULUS, INTDIV, LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS, @@ -42,7 +40,7 @@ public class BinaryScalar extends Lop IQSIZE, } - OperationTypes operation; + private final OperationTypes operation; /** * Constructor to perform a scalar operation @@ -66,7 +64,7 @@ public class BinaryScalar extends Lop boolean aligner = false; boolean definesMRJob = false; lps.addCompatibility(JobType.INVALID); - this.lps.setProperties(inputs, ExecType.CP, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob ); + lps.setProperties(inputs, ExecType.CP, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob ); } @Override @@ -74,23 +72,19 @@ public class BinaryScalar extends Lop return "Operation: " + operation; } - public OperationTypes getOperationType(){ + public OperationTypes getOperationType() { return operation; } @Override public String getInstructions(String input1, String input2, String output) throws LopsException { - String opString = getOpcode( operation ); - - - StringBuilder sb = new StringBuilder(); sb.append(getExecType()); sb.append(Lop.OPERAND_DELIMITOR); - sb.append( opString ); + sb.append( getOpcode(operation) ); sb.append( OPERAND_DELIMITOR ); sb.append( getInputs().get(0).prepScalarInputOperand(getExecType()) ); @@ -105,17 +99,15 @@ public class BinaryScalar extends Lop } @Override - public Lop.SimpleInstType getSimpleInstructionType() - { - switch (operation){ - - default: - return SimpleInstType.Scalar; - } + public Lop.SimpleInstType getSimpleInstructionType() { + return SimpleInstType.Scalar; } public static String getOpcode( OperationTypes op ) { + if( op == null ) + throw new UnsupportedOperationException("Unable to get opcode for 'null'."); + switch ( op ) { /* Arithmetic */ @@ -169,7 +161,8 @@ public class BinaryScalar extends Lop return "iqsize"; default: - throw new 
UnsupportedOperationException("Instruction is not defined for BinaryScalar operator: " + op); + throw new UnsupportedOperationException("Instruction " + + "is not defined for BinaryScalar operator: " + op); } } }
