Repository: systemml
Updated Branches:
  refs/heads/master 0eff9f28d -> 3cbd9d5ab


[SYSTEMML-2498] Fix codegen compiler for cbind w/ vectors and scalars

This patch fixes the codegen compiler for binary and nary cbind
operations to (1) not compile row templates for cbind operations with
row vectors, and (2) robustly handle a mix of matrix and column vector
inputs, where the column vectors become scalars in the context of a row
template.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3cbd9d5a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3cbd9d5a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3cbd9d5a

Branch: refs/heads/master
Commit: 3cbd9d5ab0e9cd29b4e67183129deaa549c10d30
Parents: 0eff9f2
Author: Matthias Boehm <[email protected]>
Authored: Sat Oct 27 20:57:37 2018 +0200
Committer: Matthias Boehm <[email protected]>
Committed: Sat Oct 27 20:57:37 2018 +0200

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeNary.java     | 23 ++++++++++-------
 .../hops/codegen/template/TemplateRow.java      | 14 ++++++-----
 .../test/integration/AutomatedTestBase.java     | 26 ++++++++++----------
 3 files changed, 35 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3cbd9d5a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
index 28e47f4..1a717d3 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
@@ -53,15 +53,20 @@ public class CNodeNary extends CNode
                                                boolean sparseInput = sparseGen 
&& input instanceof CNodeData
                                                        && 
input.getVarname().startsWith("a");
                                                String varj = 
input.getVarname();
-                                               String pos = (input instanceof 
CNodeData && input.getDataType().isMatrix()) ? 
-                                                               
(!varj.startsWith("b")) ? varj+"i" : TemplateUtils.isMatrix(input) ? 
-                                                               varj + 
".pos(rix)" : "0" : "0";
-                                               sb.append( sparseInput ?
-                                                       "    
LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, "
-                                                               +varj+"ix, 
"+pos+", "+off+", "+input._cols+");\n" :
-                                                       "    
LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj)
-                                                               +", %TMP%, 
"+pos+", "+off+", "+input._cols+");\n");
-                                               off += input._cols;
+                                               if( 
input.getDataType()==DataType.MATRIX ) {
+                                                       String pos = (input 
instanceof CNodeData) ?
+                                                               
!varj.startsWith("b") ? varj+"i" : varj + ".pos(rix)" : "0";
+                                                       sb.append( sparseInput ?
+                                                               "    
LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, "
+                                                                       
+varj+"ix, "+pos+", "+off+", "+input._cols+");\n" :
+                                                               "    
LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj)
+                                                                       +", 
%TMP%, "+pos+", "+off+", "+input._cols+");\n");
+                                                       off += input._cols;     
+                                               }
+                                               else { //e.g., col vectors -> 
scalars
+                                                       sb.append("    
%TMP%["+off+"] = "+varj+";\n");
+                                                       off ++;
+                                               }
                                        }
                                        return sb.toString();
                                case VECT_MAX_POOL:

http://git-wip-us.apache.org/repos/asf/systemml/blob/3cbd9d5a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index d9da27b..79213eb 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -92,9 +92,8 @@ public class TemplateRow extends TemplateBase
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
                        || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp)
                                && TemplateCell.isValidOperation(hop) && 
hop.getDim1() > 1)
-                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
                        || HopRewriteUtils.isTernary(hop, OpOp3.PLUS_MULT, 
OpOp3.MINUS_MULT)
-                       || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+                       || isValidBinaryNaryCBind(hop)
                        || (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) 
&& hop.isMatrix())
                        || (hop instanceof AggBinaryOp && hop.dimsKnown() && 
hop.getDim2()==1 //MV
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
@@ -125,8 +124,7 @@ public class TemplateRow extends TemplateBase
        public boolean fuse(Hop hop, Hop input) {
                return !isClosed() && 
                        (  (hop instanceof BinaryOp && 
isValidBinaryOperation(hop)) 
-                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
-                       || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+                       || isValidBinaryNaryCBind(hop)
                        || (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) 
&& hop.isMatrix())
                        || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp) 
                                && TemplateCell.isValidOperation(hop))
@@ -156,8 +154,7 @@ public class TemplateRow extends TemplateBase
                return !isClosed() &&
                        ((hop instanceof BinaryOp && isValidBinaryOperation(hop)
                                && hop.getDim1() > 1 && input.getDim1()>1)
-                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
-                       || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && 
hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+                       || isValidBinaryNaryCBind(hop)
                        || (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) 
&& hop.isMatrix())
                        || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, 
OpOpDnn.BIASMULT)
                                && hop.getInput().get(0).dimsKnown() && 
hop.getInput().get(1).dimsKnown()
@@ -191,6 +188,11 @@ public class TemplateRow extends TemplateBase
                return TemplateUtils.isOperationSupported(hop);
        }
        
+       private static boolean isValidBinaryNaryCBind(Hop hop) {
+               return (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) || 
HopRewriteUtils.isNary(hop, OpOpN.CBIND))
+                       && hop.getInput().get(0).isMatrix() && hop.dimsKnown() 
&& hop.getInput().get(0).getDim1()>1;
+       }
+       
        private static boolean isFuseSkinnyMatrixMult(Hop hop) {
                //check for fusable but not opening matrix multiply 
(vect_outer-mult)
                Hop in1 = hop.getInput().get(0); //transpose

http://git-wip-us.apache.org/repos/asf/systemml/blob/3cbd9d5a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java 
b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index 2135f45..e3576ab 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -213,19 +213,19 @@ public abstract class AutomatedTestBase
        
        protected RUNTIME_PLATFORM setRuntimePlatform(ExecType et) {
                RUNTIME_PLATFORM platformOld = rtplatform;
-        switch (et) {
-            case MR:
-                rtplatform = RUNTIME_PLATFORM.HADOOP;
-                break;
-            case SPARK: {
-                rtplatform = RUNTIME_PLATFORM.SPARK;
-                DMLScript.USE_LOCAL_SPARK_CONFIG = true; // Always use local 
config for junit tests
-                break;
-            }
-            default:
-                rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
-                break;
-        }
+               switch (et) {
+                       case MR:
+                               rtplatform = RUNTIME_PLATFORM.HADOOP;
+                               break;
+                       case SPARK: {
+                               rtplatform = RUNTIME_PLATFORM.SPARK;
+                               DMLScript.USE_LOCAL_SPARK_CONFIG = true; // 
Always use local config for junit tests
+                               break;
+                       }
+                       default:
+                               rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
+                               break;
+               }
                return platformOld;
        }
 

Reply via email to