[SYSTEMML-1439] Improved codegen row-aggregate candidate exploration

This patch generalizes the existing candidate exploration algorithm to
enable merging of partial rowagg templates. Together with a couple of
minor fixes and cleanups, this now allows us to fuse the following
expression (from Kmeans) into a single operator.

Y = (X <= rowMins(X));
Z = (Y / rowSums(Y));
R = colSums(Z);

Note that the first row aggregate and the row comparison consume the
original rows, whereas the subsequent row aggregate and element-wise
division operate on temporary row vectors (with internal reuse of
thread-local temporary row vectors).


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/5de7beea
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/5de7beea
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/5de7beea

Branch: refs/heads/master
Commit: 5de7beea2f5d4b9d6c9a8f7f3ae152b7442cf923
Parents: 18ab98a
Author: Matthias Boehm <[email protected]>
Authored: Fri Apr 7 21:23:58 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Apr 7 21:23:58 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       |  3 +-
 .../sysml/hops/codegen/cplan/CNodeUnary.java    |  7 +++-
 .../hops/codegen/template/TemplateUtils.java    |  5 +++
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 43 +++++---------------
 .../functions/codegen/AlgorithmLinregCG.java    |  2 +-
 5 files changed, 24 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5de7beea/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 2e60732..3dfb452 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -439,7 +439,8 @@ public class SpoofCompiler
                        if( k != pos ) {
                                Hop input2 = hop.getInput().get(k);
                                if( memo.contains(input2.getHopID()) && 
!memo.get(input2.getHopID()).get(0).closed
-                                       && 
memo.get(input2.getHopID()).get(0).type == TemplateType.CellTpl && 
tpl.merge(hop, input2) ) 
+                                       && 
TemplateUtils.isType(memo.get(input2.getHopID()).get(0).type, tpl.getType(), 
TemplateType.CellTpl)
+                                       && tpl.merge(hop, input2) ) 
                                        P.crossProduct(k, -1L, 
input2.getHopID());
                                else
                                        P.crossProduct(k, -1L);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5de7beea/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index 262295c..025033b 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -130,8 +130,9 @@ public class CNodeUnary extends CNode
                sb.append(_inputs.get(0).codegen(sparse));
                
                //generate unary operation
+               boolean lsparse = sparse && (_inputs.get(0) instanceof 
CNodeData);
                String var = createVarname();
-               String tmp = _type.getTemplate(sparse);
+               String tmp = _type.getTemplate(lsparse);
                tmp = tmp.replaceAll("%TMP%", var);
                
                String varj = _inputs.get(0).getVarname();
@@ -142,7 +143,9 @@ public class CNodeUnary extends CNode
                tmp = tmp.replaceAll("%IN1%", varj );
                
                //replace start position of main input
-               String spos = !varj.startsWith("b") ? varj+"i" : "0";
+               String spos = (!varj.startsWith("b") 
+                       && _inputs.get(0) instanceof CNodeData 
+                       && _inputs.get(0).getDataType().isMatrix()) ? varj+"i" 
: "0";
                tmp = tmp.replaceAll("%POS1%", spos);
                tmp = tmp.replaceAll("%POS2%", spos);
                

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5de7beea/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index c6a259f..e8d2086 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -25,6 +25,7 @@ import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashSet;
 
+import org.apache.commons.lang.ArrayUtils;
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
@@ -294,4 +295,8 @@ public class TemplateUtils
                return ret + ((node instanceof CNodeBinary 
                        && 
((CNodeBinary)node).getType().isVectorScalarPrimitive()) ? 1 : 0);
        }
+
+       public static boolean isType(TemplateType type, TemplateType... 
validTypes) {
+               return ArrayUtils.contains(validTypes, type);
+       }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5de7beea/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index fcfc14b..a4b6ec1 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -22,6 +22,7 @@ package org.apache.sysml.hops.rewrite;
 import java.util.ArrayList;
 import java.util.HashMap;
 
+import org.apache.commons.lang.ArrayUtils;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -1066,46 +1067,24 @@ public class HopRewriteUtils
        //////////////////////////////////////
        // utils for lookup tables
        
-       public static boolean isValidOp( AggOp input, AggOp[] validTab )
-       {
-               for( AggOp valid : validTab )
-                       if( valid == input )
-                               return true;
-               return false;
+       public static boolean isValidOp( AggOp input, AggOp[] validTab ) {
+               return ArrayUtils.contains(validTab, input);
        }
        
-       public static boolean isValidOp( OpOp1 input, OpOp1[] validTab )
-       {
-               for( OpOp1 valid : validTab )
-                       if( valid == input )
-                               return true;
-               return false;
+       public static boolean isValidOp( OpOp1 input, OpOp1[] validTab ) {
+               return ArrayUtils.contains(validTab, input);
        }
        
-       public static boolean isValidOp( OpOp2 input, OpOp2[] validTab )
-       {
-               for( OpOp2 valid : validTab )
-                       if( valid == input )
-                               return true;
-               return false;
+       public static boolean isValidOp( OpOp2 input, OpOp2[] validTab ) {
+               return ArrayUtils.contains(validTab, input);
        }
        
-       public static boolean isValidOp( ReOrgOp input, ReOrgOp[] validTab )
-       {
-               for( ReOrgOp valid : validTab )
-                       if( valid == input )
-                               return true;
-               return false;
+       public static boolean isValidOp( ReOrgOp input, ReOrgOp[] validTab ) {
+               return ArrayUtils.contains(validTab, input);
        }
        
-       public static int getValidOpPos( OpOp2 input, OpOp2[] validTab )
-       {
-               for( int i=0; i<validTab.length; i++ ) {
-                        OpOp2 valid = validTab[i];
-                        if( valid == input )
-                                       return i;
-               }
-               return -1;
+       public static int getValidOpPos( OpOp2 input, OpOp2[] validTab ) {
+               return ArrayUtils.indexOf(validTab, input);
        }
        
        /**

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/5de7beea/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java
index 6e3549e..dacc6ee 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java
@@ -41,7 +41,7 @@ public class AlgorithmLinregCG extends AutomatedTestBase
        private final static String TEST_CONF = "SystemML-config-codegen.xml";
        private final static File   TEST_CONF_FILE = new File(SCRIPT_DIR + 
TEST_DIR, TEST_CONF);
        
-       private final static double eps = 1e-5;
+       private final static double eps = 1e-1;
        
        private final static int rows = 2468;
        private final static int cols = 507;

Reply via email to