Repository: systemml Updated Branches: refs/heads/master 0eff9f28d -> 3cbd9d5ab
[SYSTEMML-2498] Fix codegen compiler for cbind w/ vectors and scalars This patch fixes the codegen compiler for binary and nary cbind operations to (1) not compile row templates for cbind operations with row vectors, and (2) improve robustness for a mix of matrix and column vector inputs, where the column vectors become scalars in the context of a row template. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3cbd9d5a Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3cbd9d5a Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3cbd9d5a Branch: refs/heads/master Commit: 3cbd9d5ab0e9cd29b4e67183129deaa549c10d30 Parents: 0eff9f2 Author: Matthias Boehm <[email protected]> Authored: Sat Oct 27 20:57:37 2018 +0200 Committer: Matthias Boehm <[email protected]> Committed: Sat Oct 27 20:57:37 2018 +0200 ---------------------------------------------------------------------- .../sysml/hops/codegen/cplan/CNodeNary.java | 23 ++++++++++------- .../hops/codegen/template/TemplateRow.java | 14 ++++++----- .../test/integration/AutomatedTestBase.java | 26 ++++++++++---------- 3 files changed, 35 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/3cbd9d5a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java index 28e47f4..1a717d3 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java @@ -53,15 +53,20 @@ public class CNodeNary extends CNode boolean sparseInput = sparseGen && input instanceof CNodeData && input.getVarname().startsWith("a"); String varj = input.getVarname(); - String pos 
= (input instanceof CNodeData && input.getDataType().isMatrix()) ? - (!varj.startsWith("b")) ? varj+"i" : TemplateUtils.isMatrix(input) ? - varj + ".pos(rix)" : "0" : "0"; - sb.append( sparseInput ? - " LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, " - +varj+"ix, "+pos+", "+off+", "+input._cols+");\n" : - " LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj) - +", %TMP%, "+pos+", "+off+", "+input._cols+");\n"); - off += input._cols; + if( input.getDataType()==DataType.MATRIX ) { + String pos = (input instanceof CNodeData) ? + !varj.startsWith("b") ? varj+"i" : varj + ".pos(rix)" : "0"; + sb.append( sparseInput ? + " LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, " + +varj+"ix, "+pos+", "+off+", "+input._cols+");\n" : + " LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj) + +", %TMP%, "+pos+", "+off+", "+input._cols+");\n"); + off += input._cols; + } + else { //e.g., col vectors -> scalars + sb.append(" %TMP%["+off+"] = "+varj+";\n"); + off ++; + } } return sb.toString(); case VECT_MAX_POOL: http://git-wip-us.apache.org/repos/asf/systemml/blob/3cbd9d5a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index d9da27b..79213eb 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -92,9 +92,8 @@ public class TemplateRow extends TemplateBase && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) && TemplateCell.isValidOperation(hop) && hop.getDim1() > 1) - || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) || 
HopRewriteUtils.isTernary(hop, OpOp3.PLUS_MULT, OpOp3.MINUS_MULT) - || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) + || isValidBinaryNaryCBind(hop) || (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix()) || (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 //MV && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) @@ -125,8 +124,7 @@ public class TemplateRow extends TemplateBase public boolean fuse(Hop hop, Hop input) { return !isClosed() && ( (hop instanceof BinaryOp && isValidBinaryOperation(hop)) - || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) - || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) + || isValidBinaryNaryCBind(hop) || (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix()) || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) && TemplateCell.isValidOperation(hop)) @@ -156,8 +154,7 @@ public class TemplateRow extends TemplateBase return !isClosed() && ((hop instanceof BinaryOp && isValidBinaryOperation(hop) && hop.getDim1() > 1 && input.getDim1()>1) - || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) - || (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown()) + || isValidBinaryNaryCBind(hop) || (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix()) || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT) && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown() @@ -191,6 +188,11 @@ public class TemplateRow extends TemplateBase return TemplateUtils.isOperationSupported(hop); } + private static boolean isValidBinaryNaryCBind(Hop hop) { + return (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) || HopRewriteUtils.isNary(hop, OpOpN.CBIND)) + && hop.getInput().get(0).isMatrix() && 
hop.dimsKnown() && hop.getInput().get(0).getDim1()>1; + } + private static boolean isFuseSkinnyMatrixMult(Hop hop) { //check for fusable but not opening matrix multiply (vect_outer-mult) Hop in1 = hop.getInput().get(0); //transpose http://git-wip-us.apache.org/repos/asf/systemml/blob/3cbd9d5a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java index 2135f45..e3576ab 100644 --- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java @@ -213,19 +213,19 @@ public abstract class AutomatedTestBase protected RUNTIME_PLATFORM setRuntimePlatform(ExecType et) { RUNTIME_PLATFORM platformOld = rtplatform; - switch (et) { - case MR: - rtplatform = RUNTIME_PLATFORM.HADOOP; - break; - case SPARK: { - rtplatform = RUNTIME_PLATFORM.SPARK; - DMLScript.USE_LOCAL_SPARK_CONFIG = true; // Always use local config for junit tests - break; - } - default: - rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; - break; - } + switch (et) { + case MR: + rtplatform = RUNTIME_PLATFORM.HADOOP; + break; + case SPARK: { + rtplatform = RUNTIME_PLATFORM.SPARK; + DMLScript.USE_LOCAL_SPARK_CONFIG = true; // Always use local config for junit tests + break; + } + default: + rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; + break; + } return platformOld; }
