This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 8d4b928975 [SYSTEMDS-3508] Remove Spark Checkpoint if CP
8d4b928975 is described below

commit 8d4b928975c1cc0742ced432e90646ae22abfc9a
Author: baunsgaard <[email protected]>
AuthorDate: Thu Mar 16 17:51:42 2023 +0100

    [SYSTEMDS-3508] Remove Spark Checkpoint if CP
    
    Modify the persistent-read checkpoint rewrite rule so that it does not
    introduce a Spark checkpoint when there is only one consuming operation
    (unless that single consumer is a transient write).
    
    Closes #1790
---
 .gitignore                                         |   1 +
 src/main/java/org/apache/sysds/hops/Hop.java       |  17 +++-
 .../RewriteInjectSparkPReadCheckpointing.java      |  70 +++++++------
 .../test/functions/append/AppendChainTest.java     |   2 +-
 .../test/functions/append/AppendMatrixTest.java    |   2 +-
 .../test/functions/append/AppendVectorTest.java    |   2 +-
 .../functions/binary/matrix/MapMultChainTest.java  |  31 +++---
 .../test/functions/compress/CompressScale.java     |   2 +-
 .../functions/rewrite/RewriteSPCheckpoint.java     | 110 +++++++++++++++++++++
 .../rewrite/RewriteSlicedMatrixMultTest.java       |   4 -
 .../functions/ternary/ABATernaryAggregateTest.java |   3 +-
 .../ternary/CTableMatrixIgnoreZerosTest.java       |   6 +-
 .../test/functions/ternary/CTableSequenceTest.java |  25 +++--
 .../ternary/CentralMomentWeightsTest.java          |   9 +-
 .../functions/ternary/CovarianceWeightsTest.java   |   9 +-
 .../test/functions/ternary/FullIfElseTest.java     |   6 +-
 .../functions/ternary/QuantileWeightsTest.java     |   6 +-
 .../test/functions/ternary/TableOutputTest.java    |   6 +-
 .../functions/ternary/TernaryAggregateTest.java    |   5 +-
 .../rewrite/SPCheckpoint/RewriteSPCheckpoint.dml   |  25 +++++
 .../SPCheckpoint/RewriteSPCheckpointRemove.dml     |  23 +++++
 21 files changed, 273 insertions(+), 91 deletions(-)

diff --git a/.gitignore b/.gitignore
index 5504b55a1b..261535b8ba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -82,6 +82,7 @@ src/test/scripts/**/expected
 
 # Working directory and scratch space
 temp/*
+temp
 scratch_space/
 
 # Ruby
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index 3119208f84..890948562d 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -64,6 +64,10 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.runtime.util.UtilFunctions;
 
+/**
+ * Hop is a High level operator, that is the first intermediate representation 
compiled from the definitions supplied in
+ * DML.
+ */
 public abstract class Hop implements ParseInfo {
        protected static final Log LOG =  
LogFactory.getLog(Hop.class.getName());
 
@@ -81,11 +85,15 @@ public abstract class Hop implements ParseInfo {
        protected PrivacyConstraint _privacyConstraint = null;
        protected UpdateType _updateType = UpdateType.COPY;
 
+       /** The output Hops that are connected to this Hop */
        protected ArrayList<Hop> _parent = new ArrayList<>();
+       /** The input Hops that are connected to this Hop */
        protected ArrayList<Hop> _input = new ArrayList<>();
 
-       protected ExecType _etype = null; //currently used exec type
-       protected ExecType _etypeForced = null; //exec type forced via platform 
or external optimizer
+       /** Currently used exec type */
+       protected ExecType _etype = null; 
+       /** Exec type forced via platform or external optimizer */
+       protected ExecType _etypeForced = null; 
 
        /**
         * Field defining if the output of the operation should be federated.
@@ -93,8 +101,11 @@ public abstract class Hop implements ParseInfo {
         * If it is lout, the output should be retrieved by the coordinator.
         */
        protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
+       /** The Federated Cost of this Hop */
        protected FederatedCost _federatedCost = new FederatedCost();
+       /** The estimated number of repetitions of this Hop*/
        protected double repetitions = 1;
+       /** Boolean specifying if the repetition count is updated/assigned */
        protected boolean repetitionsUpdated = false;
 
        /**
@@ -105,7 +116,7 @@ public abstract class Hop implements ParseInfo {
         */
        protected boolean activatePrefetch;
        
-       // Estimated size for the output produced from this Hop in bytes
+       /** Estimated size for the output produced from this Hop in bytes */
        protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
        
        // Estimated size for the entire operation represented by this Hop
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
index 6ba8ae1d85..6156d4d1d4 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.hops.rewrite;
 
 import java.util.ArrayList;
+import java.util.List;
 
 import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.hops.DataOp;
@@ -27,61 +28,72 @@ import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.OptimizerUtils;
 
 /**
- * Rule: BlockSizeAndReblock. For all statement blocks, determine
- * "optimal" block size, and place reblock Hops. For now, we just
- * use BlockSize 1K x 1K and do reblock after Persistent Reads and
- * before Persistent Writes.
+ * Rule: Inject checkpointing on reading in data in all cases where the 
operand is used in more than one operation.
  */
-public class RewriteInjectSparkPReadCheckpointing extends HopRewriteRule
-{
+public class RewriteInjectSparkPReadCheckpointing extends HopRewriteRule {
        @Override
        public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) {
-               if(  !OptimizerUtils.isSparkExecutionMode()  ) 
+               if(!OptimizerUtils.isSparkExecutionMode())
                        return roots;
-               
-               if( roots == null )
+
+               if(roots == null)
                        return null;
 
-               //top-level hops never modified
-               for( Hop h : roots ) 
+               // top-level hops never modified
+               for(Hop h : roots)
                        rInjectCheckpointAfterPRead(h);
-               
+
                return roots;
        }
 
        @Override
        public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
-               //not applicable to predicates (we do not allow persistent 
reads there)
+               // not applicable to predicates (we do not allow persistent 
reads there)
                return root;
        }
 
-       private void rInjectCheckpointAfterPRead( Hop hop ) 
-       {
+       private void rInjectCheckpointAfterPRead(Hop hop) {
                if(hop.isVisited())
                        return;
-               
                // Inject checkpoints after persistent reads (for binary 
matrices only), or
                // after reblocks that cause expensive shuffling. However, 
carefully avoid
-               // unnecessary frame checkpoints (e.g., binary data or csv that 
do not cause 
+               // unnecessary frame checkpoints (e.g., binary data or csv that 
do not cause
                // shuffle) in order to prevent excessive garbage collection 
due to possibly
                // many small string objects. An alternative would be 
serialized caching.
                boolean isMatrix = hop.getDataType().isMatrix();
-               boolean isPRead = hop instanceof DataOp  && 
((DataOp)hop).getOp()==OpOpData.PERSISTENTREAD;
-               boolean isFrameException = hop.getDataType().isFrame() && 
isPRead && !((DataOp)hop).getFileFormat().isIJV();
-               
-               if( (isMatrix && isPRead) || (hop.requiresReblock() && 
!isFrameException) ) {
-                       //make given hop for checkpointing (w/ default storage 
level)
-                       //note: we do not recursively process childs here in 
order to prevent unnecessary checkpoints
-                       hop.setRequiresCheckpoint(true);
+               boolean isPRead = hop instanceof DataOp && ((DataOp) 
hop).getOp() == OpOpData.PERSISTENTREAD;
+               boolean isFrameException = hop.getDataType().isFrame() && 
isPRead && !((DataOp) hop).getFileFormat().isIJV();
+
+               // if the only operation performed is an action then do not add 
chkpoint
+               if((isMatrix && isPRead) || (hop.requiresReblock() && 
!isFrameException)) {
+                       boolean isActionOnly = isActionOnly(hop, 
hop.getParent());
+                       // make given hop for checkpointing (w/ default storage 
level)
+                       // note: we do not recursively process children here in 
order to prevent unnecessary checkpoints
+                       
+                       if(!isActionOnly)
+                               hop.setRequiresCheckpoint(true);
+                       
                }
                else {
-                       if( hop.getInput() != null ) {
-                               //process all childs (prevent concurrent 
modification by index access)
-                               for( int i=0; i<hop.getInput().size(); i++ )
-                                       rInjectCheckpointAfterPRead( 
hop.getInput().get(i) );
+                       if(hop.getInput() != null) {
+                               // process all children (prevent concurrent 
modification by index access)
+                               for(int i = 0; i < hop.getInput().size(); i++)
+                                       
rInjectCheckpointAfterPRead(hop.getInput().get(i));
                        }
                }
-               
+
                hop.setVisited();
        }
+
+       private boolean isActionOnly(Hop hop, List<Hop> parents) {
+               // if the number of consumers of this hop is equal to 1 and no 
more
+               // then do not cache block unless that one operation is 
transient write
+               if(parents.size() == 1) {
+                       return !( parents.get(0) instanceof DataOp && //
+                               ((DataOp) parents.get(0)).getOp() == 
OpOpData.TRANSIENTWRITE);
+               }
+               else
+                       return false;
+
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/append/AppendChainTest.java 
b/src/test/java/org/apache/sysds/test/functions/append/AppendChainTest.java
index f6ef0568eb..70fed4a3d8 100644
--- a/src/test/java/org/apache/sysds/test/functions/append/AppendChainTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/append/AppendChainTest.java
@@ -136,7 +136,7 @@ public class AppendChainTest extends AutomatedTestBase
                        writeInputMatrix("B2", B2, true);
        
                        int expectedCompiled = platform==ExecMode.SINGLE_NODE ?
-                               0 : 8; //3x(rblk+chkpt), append, write
+                               0 : 5; //3x(rblk), append, write
                        
                        runTest(true, false, null, expectedCompiled);
                        runRScript(true);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/append/AppendMatrixTest.java 
b/src/test/java/org/apache/sysds/test/functions/append/AppendMatrixTest.java
index 434b002960..41f9e3a310 100644
--- a/src/test/java/org/apache/sysds/test/functions/append/AppendMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/append/AppendMatrixTest.java
@@ -178,7 +178,7 @@ public class AppendMatrixTest extends AutomatedTestBase
                        writeInputMatrix("B", B, true);
                        
                        int expectedCompiled = platform==ExecMode.SINGLE_NODE ?
-                               0 : 6; //2x(rblk+chkpt), append, write
+                               0 : 4; //2x(rblk), append, write
                        
                        runTest(true, false, null, expectedCompiled);
                        runRScript(true);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/append/AppendVectorTest.java 
b/src/test/java/org/apache/sysds/test/functions/append/AppendVectorTest.java
index 8752d78507..cf97df58a8 100644
--- a/src/test/java/org/apache/sysds/test/functions/append/AppendVectorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/append/AppendVectorTest.java
@@ -106,7 +106,7 @@ public class AppendVectorTest extends AutomatedTestBase
                        writeInputMatrix("B", B, true);
                        
                        boolean exceptionExpected = false;
-                       int numExpectedJobs = (platform == 
ExecMode.SINGLE_NODE) ? 0 : 6;
+                       int numExpectedJobs = (platform == 
ExecMode.SINGLE_NODE) ? 0 : 4;
                        
                        runTest(true, exceptionExpected, null, numExpectedJobs);
                        Assert.assertEquals("Wrong number of executed Spark 
jobs.",
diff --git 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/MapMultChainTest.java
 
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/MapMultChainTest.java
index cae033fe5d..824d3e7500 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/MapMultChainTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/MapMultChainTest.java
@@ -21,19 +21,24 @@ package org.apache.sysds.test.functions.binary.matrix;
 
 import java.util.HashMap;
 
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.rewrite.RewriteSPCheckpoint;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class MapMultChainTest extends AutomatedTestBase {
+
+       protected static final Log LOG = 
LogFactory.getLog(RewriteSPCheckpoint.class.getName());
 
-public class MapMultChainTest extends AutomatedTestBase 
-{
        private final static String TEST_NAME1 = "MapMultChain";
        private final static String TEST_NAME2 = "MapMultChainWeights";
        private final static String TEST_NAME3 = "MapMultChainWeights2";
@@ -199,6 +204,7 @@ public class MapMultChainTest extends AutomatedTestBase
        
        private void runMapMultChainTest( String testname, boolean sparse, 
boolean sumProductRewrites, ExecType instType)
        {
+               setOutputBuffering(true);
                ExecMode platformOld = setExecMode(instType);
                
                //rewrite
@@ -214,7 +220,7 @@ public class MapMultChainTest extends AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[]{"-stats","-args", 
+                       programArgs = new String[]{"-explain","-stats","-args", 
                                input("X"), input("v"), input("w"), 
output("R")};
                        
                        fullRScriptName = HOME + TEST_NAME + ".R";
@@ -230,7 +236,7 @@ public class MapMultChainTest extends AutomatedTestBase
                                writeInputMatrixWithMTD("w", w, true);
                        }
                        
-                       runTest(true, false, null, -1); 
+                       runTest(null);
                        runRScript(true); 
                        
                        //compare matrices 
@@ -239,9 +245,12 @@ public class MapMultChainTest extends AutomatedTestBase
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
                        
                        //check compiled/executed jobs
-                       int numInputs = testname.equals(TEST_NAME1) ? 2 : 3;
-                       int expectedNumCompiled = numInputs + 
((instType==ExecType.SPARK) ? 
-                               (numInputs + 
(sumProductRewrites?2:((numInputs==2)?4:5))):0);
+                       int numInputs = (testname.equals(TEST_NAME1) ? 2 : 3 );
+                       int expectedNumCompiled = numInputs + 
+                               // If spark
+                               // The +1 can be removed if we update the 
checkpointing rewrite to go through blocks
+                               // 
/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java 
+                               ((instType == ExecType.SPARK) ? (1+ 
(sumProductRewrites ? 2 : ((numInputs==2)?4:5))):0);
                        checkNumCompiledSparkInst(expectedNumCompiled); 
                        checkNumExecutedSparkInst(expectedNumCompiled 
                                - ((instType==ExecType.CP)?numInputs:0));
diff --git 
a/src/test/java/org/apache/sysds/test/functions/compress/CompressScale.java 
b/src/test/java/org/apache/sysds/test/functions/compress/CompressScale.java
index 2926472a83..0360eed20a 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/CompressScale.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/CompressScale.java
@@ -76,7 +76,7 @@ public class CompressScale extends AutomatedTestBase {
                int min, int max, int scale, int center) {
 
                Types.ExecMode platformOld = setExecMode(instType);
-               setOutputBuffering(true); //otherwise test fails in local
+               setOutputBuffering(true);
                try {
 
                        fullDMLScriptName = SCRIPT_DIR + "/" + getTestDir() + 
getTestName() + ".dml";
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSPCheckpoint.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSPCheckpoint.java
new file mode 100644
index 0000000000..8cb52e78b2
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSPCheckpoint.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class RewriteSPCheckpoint extends AutomatedTestBase {
+
+       protected static final Log LOG = 
LogFactory.getLog(RewriteSPCheckpoint.class.getName());
+       private static final String TEST_NAME1 = "RewriteSPCheckpoint";
+       private static final String TEST_NAME2 = "RewriteSPCheckpointRemove";
+
+       private static final String TEST_DIR = 
"functions/rewrite/SPCheckpoint/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteSPCheckpoint.class.getSimpleName() + "/";
+
+       private static final int dim1 = 1324;
+       private static final int dim2 = 1100;
+
+       private static final double sparsity = 0.7;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
+       }
+
+       @Test
+       public void testRewriteCheckpointTransientWrite() {
+               testRewriteCheckpoint(TEST_NAME1, true);
+       }
+
+       @Test
+       public void testRewriteCheckpointTransientWriteRemove() {
+               testRewriteCheckpoint(TEST_NAME2, false);
+       }
+
+       private void testRewriteCheckpoint(String testName, boolean rewrite) {
+               setOutputBuffering(true);
+               Types.ExecMode platformOld = setExecMode(ExecMode.SPARK);
+               try {
+
+                       TestConfiguration config = 
getTestConfiguration(testName);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testName + ".dml";
+                       programArgs = new String[] {"-explain", "-stats", 
"-args", input("A")};
+                       write(TestUtils.generateTestMatrixBlock(dim1, dim2, -1, 
1, sparsity, 6), input("A"));
+                       runTest(null);
+
+                       assertTrue(rewrite == 
heavyHittersContainsString("sp_chkpoint"));
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+               finally {
+
+                       rtplatform = platformOld;
+               }
+       }
+
+       public void write(MatrixBlock mb, String path) {
+               try {
+                       MatrixWriter w = 
MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY);
+                       w.writeMatrixToHDFS(mb, path, mb.getNumRows(), 
mb.getNumColumns(), 1000, mb.getNonZeros());
+                       MatrixCharacteristics mc = new 
MatrixCharacteristics(mb.getNumRows(), mb.getNumColumns(), 1000,
+                               mb.getNonZeros());
+                       HDFSTool.writeMetaDataFile(path + ".mtd", 
ValueType.FP64, mc, FileFormat.BINARY);
+               }
+               catch(Exception e) {
+                       fail(e.getMessage());
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSlicedMatrixMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSlicedMatrixMultTest.java
index 6ab4ae0352..13bbeffea6 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSlicedMatrixMultTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSlicedMatrixMultTest.java
@@ -28,10 +28,6 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 
-/**
- * 
- * 
- */
 public class RewriteSlicedMatrixMultTest extends AutomatedTestBase 
 {
        
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ternary/ABATernaryAggregateTest.java
 
b/src/test/java/org/apache/sysds/test/functions/ternary/ABATernaryAggregateTest.java
index 261fc8dac4..59b24e9da1 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/ternary/ABATernaryAggregateTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/ternary/ABATernaryAggregateTest.java
@@ -316,6 +316,7 @@ public class ABATernaryAggregateTest extends 
AutomatedTestBase
        
        private void runTernaryAggregateTest(String testname, boolean sparse, 
boolean vectors, boolean rewrites, ExecType et)
        {
+               setOutputBuffering(true);
                //rtplatform for MR
                ExecMode platformOld = rtplatform;
                switch( et ){
@@ -350,7 +351,7 @@ public class ABATernaryAggregateTest extends 
AutomatedTestBase
                        writeInputMatrixWithMTD("A", A, true);
                        
                        //run test cases
-                       runTest(true, false, null, -1); 
+                       runTest(null); 
                        runRScript(true); 
                        
                        //compare output matrices 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ternary/CTableMatrixIgnoreZerosTest.java
 
b/src/test/java/org/apache/sysds/test/functions/ternary/CTableMatrixIgnoreZerosTest.java
index 34ee9e36f2..fdbb0735b2 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/ternary/CTableMatrixIgnoreZerosTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/ternary/CTableMatrixIgnoreZerosTest.java
@@ -111,8 +111,8 @@ public class CTableMatrixIgnoreZerosTest extends 
AutomatedTestBase
                runCTableTest(false, true, ExecType.CP);
        }
        
-       private void runCTableTest( boolean rewrite, boolean sparse, ExecType 
et)
-       {
+       private void runCTableTest( boolean rewrite, boolean sparse, ExecType 
et){
+               setOutputBuffering(true); 
                String TEST_NAME = TEST_NAME1;
                
                //rtplatform for MR
@@ -149,7 +149,7 @@ public class CTableMatrixIgnoreZerosTest extends 
AutomatedTestBase
                        double[][] A = getRandomMatrix(rows, cols, 1, 10, 
sparsity, 7); 
                        writeInputMatrixWithMTD("A", A, true);
        
-                       runTest(true, false, null, -1); 
+                       runTest(null); 
                        runRScript(true); 
                        
                        //compare matrices 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ternary/CTableSequenceTest.java 
b/src/test/java/org/apache/sysds/test/functions/ternary/CTableSequenceTest.java
index 21450012b9..b643390787 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/ternary/CTableSequenceTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/ternary/CTableSequenceTest.java
@@ -21,16 +21,18 @@ package org.apache.sysds.test.functions.ternary;
 
 import java.util.HashMap;
 
-import org.junit.Assert;
-import org.junit.Test;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
 
 /**
  * This test investigates the specific Hop-Lop rewrite 
ctable(seq(1,nrow(X)),X).
@@ -44,8 +46,10 @@ import org.apache.sysds.test.TestUtils;
  *   matrix cell.
  * 
  */
-public class CTableSequenceTest extends AutomatedTestBase 
-{
+public class CTableSequenceTest extends AutomatedTestBase {
+
+       protected static final Log LOG = 
LogFactory.getLog(CTableSequenceTest.class.getName());
+
        private final static String TEST_NAME1 = "CTableSequenceLeft";
        private final static String TEST_NAME2 = "CTableSequenceRight";
        
@@ -144,8 +148,9 @@ public class CTableSequenceTest extends AutomatedTestBase
                runCTableSequenceTest(true, false, true, ExecType.CP);
        }
        
-       private void runCTableSequenceTest(boolean rewrite, boolean left, 
boolean withAgg, ExecType et)
-       {
+       private void runCTableSequenceTest(boolean rewrite, boolean left, 
boolean withAgg, ExecType et){
+               setOutputBuffering(true);
+
                String TEST_NAME = left ? TEST_NAME1 : TEST_NAME2;
                ExecMode platformOld = rtplatform;
                boolean rewriteOld = TernaryOp.ALLOW_CTABLE_SEQUENCE_REWRITES;
@@ -168,7 +173,7 @@ public class CTableSequenceTest extends AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[]{"-stats","-args", input("A"),
+                       programArgs = new String[]{"-explain","-stats","-args", 
input("A"),
                                Integer.toString(rows),
                                Integer.toString(1),
                                Integer.toString(withAgg?1:0),
@@ -181,7 +186,7 @@ public class CTableSequenceTest extends AutomatedTestBase
                        double[][] A = TestUtils.floor(getRandomMatrix(rows, 1, 
1, maxVal, 1.0, 7)); 
                        writeInputMatrix("A", A, true);
        
-                       runTest(true, false, null, -1);
+                       runTest(null);
                        runRScript(true);
                        
                        //compare matrices 
@@ -191,7 +196,7 @@ public class CTableSequenceTest extends AutomatedTestBase
                        
                        //w/ rewrite: 4 instead of 6 because seq and 
aggregation are not required for ctable_expand
                        //2 for CP due to reblock jobs for input and table
-                       int expectedNumCompiled = ((et==ExecType.CP) ? 2 : 
5)+(withAgg ? 1 : 0);
+                       int expectedNumCompiled = ((et==ExecType.CP) ? 2 : 
4)+(withAgg ? 1 : 0);
                        checkNumCompiledSparkInst(expectedNumCompiled);
                        Assert.assertEquals(left & rewrite,
                                heavyHittersContainsSubString("ctableexpand"));
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ternary/CentralMomentWeightsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/ternary/CentralMomentWeightsTest.java
index e30de9343f..617bc6d8a6 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/ternary/CentralMomentWeightsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/ternary/CentralMomentWeightsTest.java
@@ -125,15 +125,10 @@ public class CentralMomentWeightsTest extends 
AutomatedTestBase
        {
                runCentralMomentTest(4, true, ExecType.SPARK);
        }
-       
-       /**
-        * 
-        * @param sparseM1
-        * @param sparseM2
-        * @param instType
-        */
+
        private void runCentralMomentTest( int order, boolean sparse, ExecType 
et)
        {
+               setOutputBuffering(true); 
                //rtplatform for MR
                ExecMode platformOld = rtplatform;
                switch( et ){
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ternary/CovarianceWeightsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/ternary/CovarianceWeightsTest.java
index 0b6f2af934..5379259b4d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/ternary/CovarianceWeightsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/ternary/CovarianceWeightsTest.java
@@ -78,14 +78,9 @@ public class CovarianceWeightsTest extends AutomatedTestBase
                runCovarianceTest(true, ExecType.SPARK);
        }
        
-       /**
-        * 
-        * @param sparseM1
-        * @param sparseM2
-        * @param instType
-        */
        private void runCovarianceTest( boolean sparse, ExecType et)
        {
+               setOutputBuffering(true);
                //rtplatform for MR
                ExecMode platformOld = rtplatform;
                switch( et ){
@@ -121,7 +116,7 @@ public class CovarianceWeightsTest extends AutomatedTestBase
                        double[][] C = getRandomMatrix(rows, 1, 1, 1, 1.0, 
8623); 
                        writeInputMatrixWithMTD("C", C, true);  
                        
-                       runTest(true, false, null, -1); 
+                       runTest(null); 
                        runRScript(true); 
                        
                        //compare matrices 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ternary/FullIfElseTest.java 
b/src/test/java/org/apache/sysds/test/functions/ternary/FullIfElseTest.java
index 2ae9093aa6..09306eb5fd 100644
--- a/src/test/java/org/apache/sysds/test/functions/ternary/FullIfElseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ternary/FullIfElseTest.java
@@ -212,8 +212,8 @@ public class FullIfElseTest extends AutomatedTestBase
                runIfElseTest(true, true, true, true, ExecType.SPARK);
        }
 
-       private void runIfElseTest(boolean matrix1, boolean matrix2, boolean 
matrix3, boolean sparse, ExecType et)
-       {
+       private void runIfElseTest(boolean matrix1, boolean matrix2, boolean 
matrix3, boolean sparse, ExecType et){
+               setOutputBuffering(true);
                //rtplatform for MR
                ExecMode platformOld = rtplatform;
                switch( et ){
@@ -248,7 +248,7 @@ public class FullIfElseTest extends AutomatedTestBase
                        writeInputMatrixWithMTD("C", C, true);
                        
                        //run test cases
-                       runTest(true, false, null, -1);
+                       runTest(null);
                        runRScript(true); 
                        
                        //compare output matrices 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ternary/QuantileWeightsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/ternary/QuantileWeightsTest.java
index 811b874d8f..6aed3b9648 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/ternary/QuantileWeightsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/ternary/QuantileWeightsTest.java
@@ -153,8 +153,8 @@ public class QuantileWeightsTest extends AutomatedTestBase
                runQuantileTest(TEST_NAME3, -1, true, ExecType.SPARK);
        }
        
-       private void runQuantileTest( String TEST_NAME, double p, boolean 
sparse, ExecType et)
-       {
+       private void runQuantileTest( String TEST_NAME, double p, boolean 
sparse, ExecType et){
+               setOutputBuffering(true);
                //rtplatform for MR
                ExecMode platformOld = rtplatform;
                switch( et ){
@@ -186,7 +186,7 @@ public class QuantileWeightsTest extends AutomatedTestBase
                        double[][] W = getRandomMatrix(rows, 1, 1, 1, 1.0, 1); 
                        writeInputMatrixWithMTD("W", W, true);
                        
-                       runTest(true, false, null, -1); 
+                       runTest(null); 
                        runRScript(true); 
                        
                        //compare matrices 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ternary/TableOutputTest.java 
b/src/test/java/org/apache/sysds/test/functions/ternary/TableOutputTest.java
index d13f7e0d00..3b17c0385a 100644
--- a/src/test/java/org/apache/sysds/test/functions/ternary/TableOutputTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ternary/TableOutputTest.java
@@ -86,8 +86,8 @@ public class TableOutputTest extends AutomatedTestBase
                runTableOutputTest(ExecType.CP, -5);
        }
        
-       private void runTableOutputTest( ExecType et, int delta)
-       {
+       private void runTableOutputTest(ExecType et, int delta) {
+               setOutputBuffering(true);
                //rtplatform for MR
                ExecMode platformOld = rtplatform;
 
@@ -124,7 +124,7 @@ public class TableOutputTest extends AutomatedTestBase
                        double[][] B = TestUtils.floor(getRandomMatrix(rows, 1, 
1, maxVal2, 1.0, -1)); 
                        writeInputMatrixWithMTD("B", B, true);
                        
-                       runTest(true, false, null, -1); 
+                       runTest(null); 
                        runRScript(true); 
                        
                        //compare matrices 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
 
b/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
index a87908aba0..a7224d88b0 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
@@ -181,10 +181,9 @@ public class TernaryAggregateTest extends AutomatedTestBase
                runTernaryAggregateTest(TEST_NAME2, true, false, false, 
ExecType.CP);
        }
        
-       
-       
        private void runTernaryAggregateTest(String testname, boolean sparse, 
boolean vectors, boolean rewrites, ExecType et)
        {
+               setOutputBuffering(true);
                //rtplatform for MR
                ExecMode platformOld = rtplatform;
                switch( et ){
@@ -220,7 +219,7 @@ public class TernaryAggregateTest extends AutomatedTestBase
                        writeInputMatrixWithMTD("A", A, true);
                        
                        //run test cases
-                       runTest(true, false, null, -1); 
+                       runTest(null); 
                        runRScript(true); 
                        
                        //compare output matrices 
diff --git 
a/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpoint.dml 
b/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpoint.dml
new file mode 100644
index 0000000000..f7eaa0e9fb
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpoint.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+while(FALSE){}
+print(sum(A))
+print(mean(A))
diff --git 
a/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpointRemove.dml 
b/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpointRemove.dml
new file mode 100644
index 0000000000..9a8605a145
--- /dev/null
+++ 
b/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpointRemove.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+print(sum(A))


Reply via email to