Repository: systemml Updated Branches: refs/heads/master 8084dc127 -> 1bcdfaac1
[SYSTEMML-2134] Fix codegen row tmpl support for vector ternary axpy This patch fixes the CPlan construction of row templates for ternary axpy operations with row vector intermediates. Specifically, we now add index lookups only for scalar intermediates; adding them for row vector intermediates previously caused codegen compilation errors. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1bcdfaac Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1bcdfaac Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1bcdfaac Branch: refs/heads/master Commit: 1bcdfaac138d8f68e4144ca4ddbaf8cf03329ca1 Parents: 8084dc1 Author: Matthias Boehm <[email protected]> Authored: Mon Jun 4 14:25:05 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Jun 4 17:45:07 2018 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/hops/codegen/cplan/CNodeBinary.java | 7 +++---- .../apache/sysml/hops/codegen/cplan/CNodeTernary.java | 8 +++----- .../org/apache/sysml/hops/codegen/cplan/CNodeUnary.java | 7 +++---- .../apache/sysml/hops/codegen/template/TemplateRow.java | 11 ++++++----- 4 files changed, 15 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/1bcdfaac/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index 80a7f83..7ef21c1 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -19,6 +19,8 @@ package org.apache.sysml.hops.codegen.cplan; +import java.util.Arrays; + import org.apache.commons.lang.StringUtils; import 
org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.parser.Expression.DataType; @@ -56,10 +58,7 @@ public class CNodeBinary extends CNode MINUS1_MULT, MINUS_NZ; public static boolean contains(String value) { - for( BinType bt : values() ) - if( bt.name().equals(value) ) - return true; - return false; + return Arrays.stream(values()).anyMatch(bt -> bt.name().equals(value)); } public boolean isCommutative() { http://git-wip-us.apache.org/repos/asf/systemml/blob/1bcdfaac/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java index dc8ff82..61140b4 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java @@ -19,6 +19,8 @@ package org.apache.sysml.hops.codegen.cplan; +import java.util.Arrays; + import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.util.UtilFunctions; @@ -30,12 +32,8 @@ public class CNodeTernary extends CNode REPLACE, REPLACE_NAN, IFELSE, LOOKUP_RC1, LOOKUP_RVECT1; - public static boolean contains(String value) { - for( TernaryType tt : values() ) - if( tt.name().equals(value) ) - return true; - return false; + return Arrays.stream(values()).anyMatch(tt -> tt.name().equals(value)); } public String getTemplate(boolean sparse) { http://git-wip-us.apache.org/repos/asf/systemml/blob/1bcdfaac/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index d7721a1..b269139 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -19,6 +19,8 @@ package org.apache.sysml.hops.codegen.cplan; +import java.util.Arrays; + import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang.StringUtils; import org.apache.sysml.hops.codegen.template.TemplateUtils; @@ -43,10 +45,7 @@ public class CNodeUnary extends CNode SPROP, SIGMOID; public static boolean contains(String value) { - for( UnaryType ut : values() ) - if( ut.name().equals(value) ) - return true; - return false; + return Arrays.stream(values()).anyMatch(ut -> ut.name().equals(value)); } public String getTemplate(boolean sparse) { http://git-wip-us.apache.org/repos/asf/systemml/blob/1bcdfaac/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index ed68e75..737b30d 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -448,18 +448,19 @@ public class TemplateRow extends TemplateBase CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); CNode cdata3 = tmp.get(hop.getInput().get(2).getHopID()); - //add lookups if required - cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); - cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2)); - if( hop.getDim2() > 2 ) { //row vectors out = new CNodeBinary(cdata1, new CNodeBinary(cdata2, cdata3, BinType.VECT_MULT_SCALAR), top.getOp()==OpOp3.PLUS_MULT? 
BinType.VECT_PLUS : BinType.VECT_MINUS); } else { + //add lookups if required + cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); + cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1)); + cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2)); + //construct scalar ternary cnode, primitive operation derived from OpOp3 out = new CNodeTernary(cdata1, cdata2, cdata3, - TernaryType.valueOf(top.getOp().toString())); + TernaryType.valueOf(top.getOp().name())); } } else if(HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {
