Repository: systemml
Updated Branches:
  refs/heads/master 341a1dc78 -> 8895ebc45


[SYSTEMML-2508] Improved spark cumagg compilation (single row block)

This patch improves the compilation of Spark cumulative aggregates where
the input matrix has a single row block (i.e., the number of rows does not
exceed the number of rows per block) by avoiding the unnecessary offset
computation.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8895ebc4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8895ebc4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8895ebc4

Branch: refs/heads/master
Commit: 8895ebc454ce85e823d6332e40d7effd874e59df
Parents: 341a1dc
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sun Dec 16 16:04:01 2018 +0100
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sun Dec 16 17:07:01 2018 +0100

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/UnaryOp.java     | 39 +++++++++++++-------
 .../misc/RewriteCumulativeAggregatesTest.java   | 25 +++++++++++--
 .../misc/RewriteCumulativeAggregates.R          |  6 ++-
 3 files changed, 52 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/8895ebc4/src/main/java/org/apache/sysml/hops/UnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java
index 2952e85..77655de 100644
--- a/src/main/java/org/apache/sysml/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java
@@ -22,6 +22,7 @@ package org.apache.sysml.hops;
 import java.util.ArrayList;
 
 import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.lops.Aggregate;
 import org.apache.sysml.lops.Checkpoint;
 import org.apache.sysml.lops.Aggregate.OperationTypes;
@@ -455,8 +456,15 @@ public class UnaryOp extends MultiThreadedHop
                long bclen = input.getColsInBlock();
                boolean force = !dimsKnown() || _etypeForced == ExecType.SPARK;
                OperationTypes aggtype = getCumulativeAggType();
-               
                Lop X = input.constructLops();
+               
+               //special case single row block (no offsets needed)
+               if( rlen > 0 && clen > 0 && rlen <= brlen ) {
+                       Lop offset = HopRewriteUtils.createDataGenOpByVal(new 
LiteralOp(1),
+                               new LiteralOp(clen), 
getCumulativeInitValue()).constructLops();
+                       return constructCumOffBinary(X, offset, aggtype, rlen, 
clen, brlen, bclen);
+               }
+               
                Lop TEMP = X;
                ArrayList<Lop> DATA = new ArrayList<>();
                int level = 0;
@@ -497,22 +505,27 @@ public class UnaryOp extends MultiThreadedHop
                
                //split, group and mr cumsum
                while( level-- > 0  ) {
-                       //(for spark, the CumulativeOffsetBinary subsumes both 
the split aggregate and 
-                       //the subsequent offset binary apply of split 
aggregates against the original data)
-                       double initValue = getCumulativeInitValue();
-                       boolean broadcast = ALLOW_CUMAGG_BROADCAST
-                               && 
OptimizerUtils.checkSparkBroadcastMemoryBudget(OptimizerUtils.estimateSize(
-                               TEMP.getOutputParameters().getNumRows(), 
TEMP.getOutputParameters().getNumCols()));
-                       
-                       CumulativeOffsetBinary binary = new 
CumulativeOffsetBinary(DATA.get(level), TEMP, 
-                                       DataType.MATRIX, ValueType.DOUBLE, 
initValue, broadcast, aggtype, ExecType.SPARK);
-                       binary.getOutputParameters().setDimensions(rlen, clen, 
brlen, bclen, -1);
-                       setLineNumbers(binary);
-                       TEMP = binary;
+                       TEMP = constructCumOffBinary(DATA.get(level),
+                               TEMP, aggtype, rlen, clen, brlen, bclen);
                }
                
                return TEMP;
        }
+       
+       private Lop constructCumOffBinary(Lop data, Lop offset, OperationTypes 
aggtype, long rlen, long clen, long brlen, long bclen) {
+               //(for spark, the CumulativeOffsetBinary subsumes both the 
split aggregate and 
+               //the subsequent offset binary apply of split aggregates 
against the original data)
+               double initValue = getCumulativeInitValue();
+               boolean broadcast = ALLOW_CUMAGG_BROADCAST
+                       && 
OptimizerUtils.checkSparkBroadcastMemoryBudget(OptimizerUtils.estimateSize(
+                       offset.getOutputParameters().getNumRows(), 
offset.getOutputParameters().getNumCols()));
+               
+               CumulativeOffsetBinary binary = new 
CumulativeOffsetBinary(data, offset, 
+                               DataType.MATRIX, ValueType.DOUBLE, initValue, 
broadcast, aggtype, ExecType.SPARK);
+               binary.getOutputParameters().setDimensions(rlen, clen, brlen, 
bclen, -1);
+               setLineNumbers(binary);
+               return binary;
+       }
 
        private OperationTypes getCumulativeAggType() {
                switch( _op ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/8895ebc4/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
index da13502..9c7f9f8 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCumulativeAggregatesTest.java
@@ -24,19 +24,22 @@ import java.util.HashMap;
 import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
 import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
 
 public class RewriteCumulativeAggregatesTest extends AutomatedTestBase 
-{      
+{
        private static final String TEST_NAME = "RewriteCumulativeAggregates";
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteCumulativeAggregatesTest.class.getSimpleName() + "/";
        
        private static final int rows = 1234;
+       private static final int rows2 = 876;
        private static final int cols = 7;
        
        @Override
@@ -85,9 +88,19 @@ public class RewriteCumulativeAggregatesTest extends 
AutomatedTestBase
                testCumAggRewrite(4, true);
        }
        
-       private void testCumAggRewrite(int num, boolean rewrites)
+       @Test
+       public void testCumAggRewrite4SPSingleRowBlock() {
+               testCumAggRewrite(4, true, ExecType.SPARK);
+       }
+       
+       private void testCumAggRewrite(int num, boolean rewrites) {
+               testCumAggRewrite(num, rewrites, ExecType.CP);
+       }
+       
+       private void testCumAggRewrite(int num, boolean rewrites, ExecType et)
        {
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               RUNTIME_PLATFORM platformOld = setRuntimePlatform(et);
                
                try {
                        TestConfiguration config = 
getTestConfiguration(TEST_NAME);
@@ -95,14 +108,15 @@ public class RewriteCumulativeAggregatesTest extends 
AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[]{ "-stats", "-args",
+                       programArgs = new String[]{ "-explain","-stats", 
"-args",
                                input("A"), String.valueOf(num), output("R") };
                        rCmd = getRCmd(inputDir(), String.valueOf(num), 
expectedDir());
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
                        
                        //generate input data
-                       double[][] A = getRandomMatrix((num==4)?1:rows,
+                       double[][] A = getRandomMatrix((num==4)?
+                               et==ExecType.CP?1:rows2:rows,
                                (num==1)?rows:cols, -1, 1, 0.9, 7); 
                        writeInputMatrixWithMTD("A", A, true);
                        
@@ -118,8 +132,11 @@ public class RewriteCumulativeAggregatesTest extends 
AutomatedTestBase
                        //check applied rewrites
                        if( rewrites )
                                
Assert.assertTrue(!heavyHittersContainsString((num==2) ? "rev" : "ucumk+"));
+                       if( num==4 && et==ExecType.SPARK )
+                               
Assert.assertTrue(!heavyHittersContainsString("ucumk+","ucumack+"));
                }
                finally {
+                       rtplatform = platformOld;
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
                }
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/8895ebc4/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
index f8a8576..390953b 100644
--- a/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
+++ b/src/test/scripts/functions/misc/RewriteCumulativeAggregates.R
@@ -37,7 +37,11 @@ if( num == 1 ) {
 } else if( num == 3 ) {
   R = t(as.matrix(colSums(apply(X, 2, cumsum))));
 } else if( num == 4 ) {
-  R = X;
+  if( nrow(X)==1 ) {
+    R = X;
+  } else {
+    R = apply(X, 2, cumsum);
+  }
 }
 
 writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); 

Reply via email to