Repository: systemml
Updated Branches:
  refs/heads/master 8084dc127 -> 1bcdfaac1


[SYSTEMML-2134] Fix codegen row tmpl support for vector ternary axpy

This patch fixes the CPlan construction of row templates for ternary
axpy operations with row vector intermediates. Specifically, we now
add index lookups only for scalar intermediates; adding them for row
vector intermediates previously caused codegen compilation errors.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1bcdfaac
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1bcdfaac
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1bcdfaac

Branch: refs/heads/master
Commit: 1bcdfaac138d8f68e4144ca4ddbaf8cf03329ca1
Parents: 8084dc1
Author: Matthias Boehm <[email protected]>
Authored: Mon Jun 4 14:25:05 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Jun 4 17:45:07 2018 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/codegen/cplan/CNodeBinary.java |  7 +++----
 .../apache/sysml/hops/codegen/cplan/CNodeTernary.java    |  8 +++-----
 .../org/apache/sysml/hops/codegen/cplan/CNodeUnary.java  |  7 +++----
 .../apache/sysml/hops/codegen/template/TemplateRow.java  | 11 ++++++-----
 4 files changed, 15 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/1bcdfaac/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index 80a7f83..7ef21c1 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysml.hops.codegen.cplan;
 
+import java.util.Arrays;
+
 import org.apache.commons.lang.StringUtils;
 import org.apache.sysml.hops.codegen.template.TemplateUtils;
 import org.apache.sysml.parser.Expression.DataType;
@@ -56,10 +58,7 @@ public class CNodeBinary extends CNode
                MINUS1_MULT, MINUS_NZ;
 
                public static boolean contains(String value) {
-                       for( BinType bt : values()  )
-                               if( bt.name().equals(value) )
-                                       return true;
-                       return false;
+                       return Arrays.stream(values()).anyMatch(bt -> 
bt.name().equals(value));
                }
                
                public boolean isCommutative() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/1bcdfaac/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
index dc8ff82..61140b4 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysml.hops.codegen.cplan;
 
+import java.util.Arrays;
+
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
@@ -30,12 +32,8 @@ public class CNodeTernary extends CNode
                REPLACE, REPLACE_NAN, IFELSE,
                LOOKUP_RC1, LOOKUP_RVECT1;
                
-               
                public static boolean contains(String value) {
-                       for( TernaryType tt : values()  )
-                               if( tt.name().equals(value) )
-                                       return true;
-                       return false;
+                       return Arrays.stream(values()).anyMatch(tt -> 
tt.name().equals(value));
                }
                
                public String getTemplate(boolean sparse) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/1bcdfaac/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index d7721a1..b269139 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysml.hops.codegen.cplan;
 
+import java.util.Arrays;
+
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.commons.lang.StringUtils;
 import org.apache.sysml.hops.codegen.template.TemplateUtils;
@@ -43,10 +45,7 @@ public class CNodeUnary extends CNode
                SPROP, SIGMOID; 
                
                public static boolean contains(String value) {
-                       for( UnaryType ut : values()  )
-                               if( ut.name().equals(value) )
-                                       return true;
-                       return false;
+                       return Arrays.stream(values()).anyMatch(ut -> 
ut.name().equals(value));
                }
                
                public String getTemplate(boolean sparse) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/1bcdfaac/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index ed68e75..737b30d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -448,18 +448,19 @@ public class TemplateRow extends TemplateBase
                        CNode cdata2 = 
tmp.get(hop.getInput().get(1).getHopID());
                        CNode cdata3 = 
tmp.get(hop.getInput().get(2).getHopID());
                        
-                       //add lookups if required
-                       cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, 
hop.getInput().get(0));
-                       cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, 
hop.getInput().get(2));
-                       
                        if( hop.getDim2() > 2 ) { //row vectors
                                out = new CNodeBinary(cdata1, new 
CNodeBinary(cdata2, cdata3, BinType.VECT_MULT_SCALAR),
                                        top.getOp()==OpOp3.PLUS_MULT? 
BinType.VECT_PLUS : BinType.VECT_MINUS);
                        }
                        else {
+                               //add lookups if required
+                               cdata1 = 
TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
+                               cdata2 = 
TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
+                               cdata3 = 
TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
+                               
                                //construct scalar ternary cnode, primitive 
operation derived from OpOp3 
                                out = new CNodeTernary(cdata1, cdata2, cdata3, 
-                                       
TernaryType.valueOf(top.getOp().toString()));
+                                       
TernaryType.valueOf(top.getOp().name()));
                        }
                }
                else if(HopRewriteUtils.isNary(hop, OpOpN.CBIND)) {

Reply via email to