This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 8d4b928975 [SYSTEMDS-3508] Remove Spark Checkpoint if CP
8d4b928975 is described below
commit 8d4b928975c1cc0742ced432e90646ae22abfc9a
Author: baunsgaard <[email protected]>
AuthorDate: Thu Mar 16 17:51:42 2023 +0100
[SYSTEMDS-3508] Remove Spark Checkpoint if CP
Modify the rewrite rule so that it does not introduce a Spark checkpoint
when there is only one consuming operation; in that case the checkpoint is removed.
Closes #1790
---
.gitignore | 1 +
src/main/java/org/apache/sysds/hops/Hop.java | 17 +++-
.../RewriteInjectSparkPReadCheckpointing.java | 70 +++++++------
.../test/functions/append/AppendChainTest.java | 2 +-
.../test/functions/append/AppendMatrixTest.java | 2 +-
.../test/functions/append/AppendVectorTest.java | 2 +-
.../functions/binary/matrix/MapMultChainTest.java | 31 +++---
.../test/functions/compress/CompressScale.java | 2 +-
.../functions/rewrite/RewriteSPCheckpoint.java | 110 +++++++++++++++++++++
.../rewrite/RewriteSlicedMatrixMultTest.java | 4 -
.../functions/ternary/ABATernaryAggregateTest.java | 3 +-
.../ternary/CTableMatrixIgnoreZerosTest.java | 6 +-
.../test/functions/ternary/CTableSequenceTest.java | 25 +++--
.../ternary/CentralMomentWeightsTest.java | 9 +-
.../functions/ternary/CovarianceWeightsTest.java | 9 +-
.../test/functions/ternary/FullIfElseTest.java | 6 +-
.../functions/ternary/QuantileWeightsTest.java | 6 +-
.../test/functions/ternary/TableOutputTest.java | 6 +-
.../functions/ternary/TernaryAggregateTest.java | 5 +-
.../rewrite/SPCheckpoint/RewriteSPCheckpoint.dml | 25 +++++
.../SPCheckpoint/RewriteSPCheckpointRemove.dml | 23 +++++
21 files changed, 273 insertions(+), 91 deletions(-)
diff --git a/.gitignore b/.gitignore
index 5504b55a1b..261535b8ba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -82,6 +82,7 @@ src/test/scripts/**/expected
# Working directory and scratch space
temp/*
+temp
scratch_space/
# Ruby
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index 3119208f84..890948562d 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -64,6 +64,10 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.util.UtilFunctions;
+/**
+ * Hop is a High level operator, that is the first intermediate representation
compiled from the definitions supplied in
+ * DML.
+ */
public abstract class Hop implements ParseInfo {
protected static final Log LOG =
LogFactory.getLog(Hop.class.getName());
@@ -81,11 +85,15 @@ public abstract class Hop implements ParseInfo {
protected PrivacyConstraint _privacyConstraint = null;
protected UpdateType _updateType = UpdateType.COPY;
+ /** The output Hops that are connected to this Hop */
protected ArrayList<Hop> _parent = new ArrayList<>();
+ /** The input Hops that are connected to this Hop */
protected ArrayList<Hop> _input = new ArrayList<>();
- protected ExecType _etype = null; //currently used exec type
- protected ExecType _etypeForced = null; //exec type forced via platform
or external optimizer
+ /** Currently used exec type */
+ protected ExecType _etype = null;
+ /** Exec type forced via platform or external optimizer */
+ protected ExecType _etypeForced = null;
/**
* Field defining if the output of the operation should be federated.
@@ -93,8 +101,11 @@ public abstract class Hop implements ParseInfo {
* If it is lout, the output should be retrieved by the coordinator.
*/
protected FederatedOutput _federatedOutput = FederatedOutput.NONE;
+ /** The Federated Cost of this Hop */
protected FederatedCost _federatedCost = new FederatedCost();
+ /** The estimated number of repetitions of this Hop*/
protected double repetitions = 1;
+ /** Boolean specifying if the repetition count is updated/assigned */
protected boolean repetitionsUpdated = false;
/**
@@ -105,7 +116,7 @@ public abstract class Hop implements ParseInfo {
*/
protected boolean activatePrefetch;
- // Estimated size for the output produced from this Hop in bytes
+ /** Estimated size for the output produced from this Hop in bytes */
protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
// Estimated size for the entire operation represented by this Hop
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
index 6ba8ae1d85..6156d4d1d4 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
@@ -20,6 +20,7 @@
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
+import java.util.List;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.DataOp;
@@ -27,61 +28,72 @@ import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
/**
- * Rule: BlockSizeAndReblock. For all statement blocks, determine
- * "optimal" block size, and place reblock Hops. For now, we just
- * use BlockSize 1K x 1K and do reblock after Persistent Reads and
- * before Persistent Writes.
+ * Rule: Inject checkpointing on reading in data in all cases where the
operand is used in more than one operation.
*/
-public class RewriteInjectSparkPReadCheckpointing extends HopRewriteRule
-{
+public class RewriteInjectSparkPReadCheckpointing extends HopRewriteRule {
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots,
ProgramRewriteStatus state) {
- if( !OptimizerUtils.isSparkExecutionMode() )
+ if(!OptimizerUtils.isSparkExecutionMode())
return roots;
-
- if( roots == null )
+
+ if(roots == null)
return null;
- //top-level hops never modified
- for( Hop h : roots )
+ // top-level hops never modified
+ for(Hop h : roots)
rInjectCheckpointAfterPRead(h);
-
+
return roots;
}
@Override
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
- //not applicable to predicates (we do not allow persistent
reads there)
+ // not applicable to predicates (we do not allow persistent
reads there)
return root;
}
- private void rInjectCheckpointAfterPRead( Hop hop )
- {
+ private void rInjectCheckpointAfterPRead(Hop hop) {
if(hop.isVisited())
return;
-
// Inject checkpoints after persistent reads (for binary
matrices only), or
// after reblocks that cause expensive shuffling. However,
carefully avoid
- // unnecessary frame checkpoints (e.g., binary data or csv that
do not cause
+ // unnecessary frame checkpoints (e.g., binary data or csv that
do not cause
// shuffle) in order to prevent excessive garbage collection
due to possibly
// many small string objects. An alternative would be
serialized caching.
boolean isMatrix = hop.getDataType().isMatrix();
- boolean isPRead = hop instanceof DataOp &&
((DataOp)hop).getOp()==OpOpData.PERSISTENTREAD;
- boolean isFrameException = hop.getDataType().isFrame() &&
isPRead && !((DataOp)hop).getFileFormat().isIJV();
-
- if( (isMatrix && isPRead) || (hop.requiresReblock() &&
!isFrameException) ) {
- //make given hop for checkpointing (w/ default storage
level)
- //note: we do not recursively process childs here in
order to prevent unnecessary checkpoints
- hop.setRequiresCheckpoint(true);
+ boolean isPRead = hop instanceof DataOp && ((DataOp)
hop).getOp() == OpOpData.PERSISTENTREAD;
+ boolean isFrameException = hop.getDataType().isFrame() &&
isPRead && !((DataOp) hop).getFileFormat().isIJV();
+
+ // if the only operation performed is an action then do not add
chkpoint
+ if((isMatrix && isPRead) || (hop.requiresReblock() &&
!isFrameException)) {
+ boolean isActionOnly = isActionOnly(hop,
hop.getParent());
+ // make given hop for checkpointing (w/ default storage
level)
+ // note: we do not recursively process children here in
order to prevent unnecessary checkpoints
+
+ if(!isActionOnly)
+ hop.setRequiresCheckpoint(true);
+
}
else {
- if( hop.getInput() != null ) {
- //process all childs (prevent concurrent
modification by index access)
- for( int i=0; i<hop.getInput().size(); i++ )
- rInjectCheckpointAfterPRead(
hop.getInput().get(i) );
+ if(hop.getInput() != null) {
+ // process all children (prevent concurrent
modification by index access)
+ for(int i = 0; i < hop.getInput().size(); i++)
+
rInjectCheckpointAfterPRead(hop.getInput().get(i));
}
}
-
+
hop.setVisited();
}
+
+ private boolean isActionOnly(Hop hop, List<Hop> parents) {
+ // if the number of consumers of this hop is equal to 1 and no
more
+ // then do not cache block unless that one operation is
transient write
+ if(parents.size() == 1) {
+ return !( parents.get(0) instanceof DataOp && //
+ ((DataOp) parents.get(0)).getOp() ==
OpOpData.TRANSIENTWRITE);
+ }
+ else
+ return false;
+
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/append/AppendChainTest.java
b/src/test/java/org/apache/sysds/test/functions/append/AppendChainTest.java
index f6ef0568eb..70fed4a3d8 100644
--- a/src/test/java/org/apache/sysds/test/functions/append/AppendChainTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/append/AppendChainTest.java
@@ -136,7 +136,7 @@ public class AppendChainTest extends AutomatedTestBase
writeInputMatrix("B2", B2, true);
int expectedCompiled = platform==ExecMode.SINGLE_NODE ?
- 0 : 8; //3x(rblk+chkpt), append, write
+ 0 : 5; //3x(rblk), append, write
runTest(true, false, null, expectedCompiled);
runRScript(true);
diff --git
a/src/test/java/org/apache/sysds/test/functions/append/AppendMatrixTest.java
b/src/test/java/org/apache/sysds/test/functions/append/AppendMatrixTest.java
index 434b002960..41f9e3a310 100644
--- a/src/test/java/org/apache/sysds/test/functions/append/AppendMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/append/AppendMatrixTest.java
@@ -178,7 +178,7 @@ public class AppendMatrixTest extends AutomatedTestBase
writeInputMatrix("B", B, true);
int expectedCompiled = platform==ExecMode.SINGLE_NODE ?
- 0 : 6; //2x(rblk+chkpt), append, write
+ 0 : 4; //2x(rblk), append, write
runTest(true, false, null, expectedCompiled);
runRScript(true);
diff --git
a/src/test/java/org/apache/sysds/test/functions/append/AppendVectorTest.java
b/src/test/java/org/apache/sysds/test/functions/append/AppendVectorTest.java
index 8752d78507..cf97df58a8 100644
--- a/src/test/java/org/apache/sysds/test/functions/append/AppendVectorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/append/AppendVectorTest.java
@@ -106,7 +106,7 @@ public class AppendVectorTest extends AutomatedTestBase
writeInputMatrix("B", B, true);
boolean exceptionExpected = false;
- int numExpectedJobs = (platform ==
ExecMode.SINGLE_NODE) ? 0 : 6;
+ int numExpectedJobs = (platform ==
ExecMode.SINGLE_NODE) ? 0 : 4;
runTest(true, exceptionExpected, null, numExpectedJobs);
Assert.assertEquals("Wrong number of executed Spark
jobs.",
diff --git
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/MapMultChainTest.java
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/MapMultChainTest.java
index cae033fe5d..824d3e7500 100644
---
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/MapMultChainTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/MapMultChainTest.java
@@ -21,19 +21,24 @@ package org.apache.sysds.test.functions.binary.matrix;
import java.util.HashMap;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.rewrite.RewriteSPCheckpoint;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class MapMultChainTest extends AutomatedTestBase {
+
+ protected static final Log LOG =
LogFactory.getLog(RewriteSPCheckpoint.class.getName());
-public class MapMultChainTest extends AutomatedTestBase
-{
private final static String TEST_NAME1 = "MapMultChain";
private final static String TEST_NAME2 = "MapMultChainWeights";
private final static String TEST_NAME3 = "MapMultChainWeights2";
@@ -199,6 +204,7 @@ public class MapMultChainTest extends AutomatedTestBase
private void runMapMultChainTest( String testname, boolean sparse,
boolean sumProductRewrites, ExecType instType)
{
+ setOutputBuffering(true);
ExecMode platformOld = setExecMode(instType);
//rewrite
@@ -214,7 +220,7 @@ public class MapMultChainTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[]{"-stats","-args",
+ programArgs = new String[]{"-explain","-stats","-args",
input("X"), input("v"), input("w"),
output("R")};
fullRScriptName = HOME + TEST_NAME + ".R";
@@ -230,7 +236,7 @@ public class MapMultChainTest extends AutomatedTestBase
writeInputMatrixWithMTD("w", w, true);
}
- runTest(true, false, null, -1);
+ runTest(null);
runRScript(true);
//compare matrices
@@ -239,9 +245,12 @@ public class MapMultChainTest extends AutomatedTestBase
TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
//check compiled/executed jobs
- int numInputs = testname.equals(TEST_NAME1) ? 2 : 3;
- int expectedNumCompiled = numInputs +
((instType==ExecType.SPARK) ?
- (numInputs +
(sumProductRewrites?2:((numInputs==2)?4:5))):0);
+ int numInputs = (testname.equals(TEST_NAME1) ? 2 : 3 );
+ int expectedNumCompiled = numInputs +
+ // If spark
+ // The +1 can be removed if we update the
checkpointing rewrite to go through blocks
+ //
/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
+ ((instType == ExecType.SPARK) ? (1+
(sumProductRewrites ? 2 : ((numInputs==2)?4:5))):0);
checkNumCompiledSparkInst(expectedNumCompiled);
checkNumExecutedSparkInst(expectedNumCompiled
- ((instType==ExecType.CP)?numInputs:0));
diff --git
a/src/test/java/org/apache/sysds/test/functions/compress/CompressScale.java
b/src/test/java/org/apache/sysds/test/functions/compress/CompressScale.java
index 2926472a83..0360eed20a 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/CompressScale.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/CompressScale.java
@@ -76,7 +76,7 @@ public class CompressScale extends AutomatedTestBase {
int min, int max, int scale, int center) {
Types.ExecMode platformOld = setExecMode(instType);
- setOutputBuffering(true); //otherwise test fails in local
+ setOutputBuffering(true);
try {
fullDMLScriptName = SCRIPT_DIR + "/" + getTestDir() +
getTestName() + ".dml";
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSPCheckpoint.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSPCheckpoint.java
new file mode 100644
index 0000000000..8cb52e78b2
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSPCheckpoint.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class RewriteSPCheckpoint extends AutomatedTestBase {
+
+ protected static final Log LOG =
LogFactory.getLog(RewriteSPCheckpoint.class.getName());
+ private static final String TEST_NAME1 = "RewriteSPCheckpoint";
+ private static final String TEST_NAME2 = "RewriteSPCheckpointRemove";
+
+ private static final String TEST_DIR =
"functions/rewrite/SPCheckpoint/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
RewriteSPCheckpoint.class.getSimpleName() + "/";
+
+ private static final int dim1 = 1324;
+ private static final int dim2 = 1100;
+
+ private static final double sparsity = 0.7;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
+ }
+
+ @Test
+ public void testRewriteCheckpointTransientWrite() {
+ testRewriteCheckpoint(TEST_NAME1, true);
+ }
+
+ @Test
+ public void testRewriteCheckpointTransientWriteRemove() {
+ testRewriteCheckpoint(TEST_NAME2, false);
+ }
+
+ private void testRewriteCheckpoint(String testName, boolean rewrite) {
+ setOutputBuffering(true);
+ Types.ExecMode platformOld = setExecMode(ExecMode.SPARK);
+ try {
+
+ TestConfiguration config =
getTestConfiguration(testName);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testName + ".dml";
+ programArgs = new String[] {"-explain", "-stats",
"-args", input("A")};
+ write(TestUtils.generateTestMatrixBlock(dim1, dim2, -1,
1, sparsity, 6), input("A"));
+ runTest(null);
+
+ assertTrue(rewrite ==
heavyHittersContainsString("sp_chkpoint"));
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ finally {
+
+ rtplatform = platformOld;
+ }
+ }
+
+ public void write(MatrixBlock mb, String path) {
+ try {
+ MatrixWriter w =
MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY);
+ w.writeMatrixToHDFS(mb, path, mb.getNumRows(),
mb.getNumColumns(), 1000, mb.getNonZeros());
+ MatrixCharacteristics mc = new
MatrixCharacteristics(mb.getNumRows(), mb.getNumColumns(), 1000,
+ mb.getNonZeros());
+ HDFSTool.writeMetaDataFile(path + ".mtd",
ValueType.FP64, mc, FileFormat.BINARY);
+ }
+ catch(Exception e) {
+ fail(e.getMessage());
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSlicedMatrixMultTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSlicedMatrixMultTest.java
index 6ab4ae0352..13bbeffea6 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSlicedMatrixMultTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSlicedMatrixMultTest.java
@@ -28,10 +28,6 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-/**
- *
- *
- */
public class RewriteSlicedMatrixMultTest extends AutomatedTestBase
{
diff --git
a/src/test/java/org/apache/sysds/test/functions/ternary/ABATernaryAggregateTest.java
b/src/test/java/org/apache/sysds/test/functions/ternary/ABATernaryAggregateTest.java
index 261fc8dac4..59b24e9da1 100644
---
a/src/test/java/org/apache/sysds/test/functions/ternary/ABATernaryAggregateTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/ternary/ABATernaryAggregateTest.java
@@ -316,6 +316,7 @@ public class ABATernaryAggregateTest extends
AutomatedTestBase
private void runTernaryAggregateTest(String testname, boolean sparse,
boolean vectors, boolean rewrites, ExecType et)
{
+ setOutputBuffering(true);
//rtplatform for MR
ExecMode platformOld = rtplatform;
switch( et ){
@@ -350,7 +351,7 @@ public class ABATernaryAggregateTest extends
AutomatedTestBase
writeInputMatrixWithMTD("A", A, true);
//run test cases
- runTest(true, false, null, -1);
+ runTest(null);
runRScript(true);
//compare output matrices
diff --git
a/src/test/java/org/apache/sysds/test/functions/ternary/CTableMatrixIgnoreZerosTest.java
b/src/test/java/org/apache/sysds/test/functions/ternary/CTableMatrixIgnoreZerosTest.java
index 34ee9e36f2..fdbb0735b2 100644
---
a/src/test/java/org/apache/sysds/test/functions/ternary/CTableMatrixIgnoreZerosTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/ternary/CTableMatrixIgnoreZerosTest.java
@@ -111,8 +111,8 @@ public class CTableMatrixIgnoreZerosTest extends
AutomatedTestBase
runCTableTest(false, true, ExecType.CP);
}
- private void runCTableTest( boolean rewrite, boolean sparse, ExecType
et)
- {
+ private void runCTableTest( boolean rewrite, boolean sparse, ExecType
et){
+ setOutputBuffering(true);
String TEST_NAME = TEST_NAME1;
//rtplatform for MR
@@ -149,7 +149,7 @@ public class CTableMatrixIgnoreZerosTest extends
AutomatedTestBase
double[][] A = getRandomMatrix(rows, cols, 1, 10,
sparsity, 7);
writeInputMatrixWithMTD("A", A, true);
- runTest(true, false, null, -1);
+ runTest(null);
runRScript(true);
//compare matrices
diff --git
a/src/test/java/org/apache/sysds/test/functions/ternary/CTableSequenceTest.java
b/src/test/java/org/apache/sysds/test/functions/ternary/CTableSequenceTest.java
index 21450012b9..b643390787 100644
---
a/src/test/java/org/apache/sysds/test/functions/ternary/CTableSequenceTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/ternary/CTableSequenceTest.java
@@ -21,16 +21,18 @@ package org.apache.sysds.test.functions.ternary;
import java.util.HashMap;
-import org.junit.Assert;
-import org.junit.Test;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
/**
* This test investigates the specific Hop-Lop rewrite
ctable(seq(1,nrow(X)),X).
@@ -44,8 +46,10 @@ import org.apache.sysds.test.TestUtils;
* matrix cell.
*
*/
-public class CTableSequenceTest extends AutomatedTestBase
-{
+public class CTableSequenceTest extends AutomatedTestBase {
+
+ protected static final Log LOG =
LogFactory.getLog(CTableSequenceTest.class.getName());
+
private final static String TEST_NAME1 = "CTableSequenceLeft";
private final static String TEST_NAME2 = "CTableSequenceRight";
@@ -144,8 +148,9 @@ public class CTableSequenceTest extends AutomatedTestBase
runCTableSequenceTest(true, false, true, ExecType.CP);
}
- private void runCTableSequenceTest(boolean rewrite, boolean left,
boolean withAgg, ExecType et)
- {
+ private void runCTableSequenceTest(boolean rewrite, boolean left,
boolean withAgg, ExecType et){
+ setOutputBuffering(true);
+
String TEST_NAME = left ? TEST_NAME1 : TEST_NAME2;
ExecMode platformOld = rtplatform;
boolean rewriteOld = TernaryOp.ALLOW_CTABLE_SEQUENCE_REWRITES;
@@ -168,7 +173,7 @@ public class CTableSequenceTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[]{"-stats","-args", input("A"),
+ programArgs = new String[]{"-explain","-stats","-args",
input("A"),
Integer.toString(rows),
Integer.toString(1),
Integer.toString(withAgg?1:0),
@@ -181,7 +186,7 @@ public class CTableSequenceTest extends AutomatedTestBase
double[][] A = TestUtils.floor(getRandomMatrix(rows, 1,
1, maxVal, 1.0, 7));
writeInputMatrix("A", A, true);
- runTest(true, false, null, -1);
+ runTest(null);
runRScript(true);
//compare matrices
@@ -191,7 +196,7 @@ public class CTableSequenceTest extends AutomatedTestBase
//w/ rewrite: 4 instead of 6 because seq and
aggregation are not required for ctable_expand
//2 for CP due to reblock jobs for input and table
- int expectedNumCompiled = ((et==ExecType.CP) ? 2 :
5)+(withAgg ? 1 : 0);
+ int expectedNumCompiled = ((et==ExecType.CP) ? 2 :
4)+(withAgg ? 1 : 0);
checkNumCompiledSparkInst(expectedNumCompiled);
Assert.assertEquals(left & rewrite,
heavyHittersContainsSubString("ctableexpand"));
diff --git
a/src/test/java/org/apache/sysds/test/functions/ternary/CentralMomentWeightsTest.java
b/src/test/java/org/apache/sysds/test/functions/ternary/CentralMomentWeightsTest.java
index e30de9343f..617bc6d8a6 100644
---
a/src/test/java/org/apache/sysds/test/functions/ternary/CentralMomentWeightsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/ternary/CentralMomentWeightsTest.java
@@ -125,15 +125,10 @@ public class CentralMomentWeightsTest extends
AutomatedTestBase
{
runCentralMomentTest(4, true, ExecType.SPARK);
}
-
- /**
- *
- * @param sparseM1
- * @param sparseM2
- * @param instType
- */
+
private void runCentralMomentTest( int order, boolean sparse, ExecType
et)
{
+ setOutputBuffering(true);
//rtplatform for MR
ExecMode platformOld = rtplatform;
switch( et ){
diff --git
a/src/test/java/org/apache/sysds/test/functions/ternary/CovarianceWeightsTest.java
b/src/test/java/org/apache/sysds/test/functions/ternary/CovarianceWeightsTest.java
index 0b6f2af934..5379259b4d 100644
---
a/src/test/java/org/apache/sysds/test/functions/ternary/CovarianceWeightsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/ternary/CovarianceWeightsTest.java
@@ -78,14 +78,9 @@ public class CovarianceWeightsTest extends AutomatedTestBase
runCovarianceTest(true, ExecType.SPARK);
}
- /**
- *
- * @param sparseM1
- * @param sparseM2
- * @param instType
- */
private void runCovarianceTest( boolean sparse, ExecType et)
{
+ setOutputBuffering(true);
//rtplatform for MR
ExecMode platformOld = rtplatform;
switch( et ){
@@ -121,7 +116,7 @@ public class CovarianceWeightsTest extends AutomatedTestBase
double[][] C = getRandomMatrix(rows, 1, 1, 1, 1.0,
8623);
writeInputMatrixWithMTD("C", C, true);
- runTest(true, false, null, -1);
+ runTest(null);
runRScript(true);
//compare matrices
diff --git
a/src/test/java/org/apache/sysds/test/functions/ternary/FullIfElseTest.java
b/src/test/java/org/apache/sysds/test/functions/ternary/FullIfElseTest.java
index 2ae9093aa6..09306eb5fd 100644
--- a/src/test/java/org/apache/sysds/test/functions/ternary/FullIfElseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ternary/FullIfElseTest.java
@@ -212,8 +212,8 @@ public class FullIfElseTest extends AutomatedTestBase
runIfElseTest(true, true, true, true, ExecType.SPARK);
}
- private void runIfElseTest(boolean matrix1, boolean matrix2, boolean
matrix3, boolean sparse, ExecType et)
- {
+ private void runIfElseTest(boolean matrix1, boolean matrix2, boolean
matrix3, boolean sparse, ExecType et){
+ setOutputBuffering(true);
//rtplatform for MR
ExecMode platformOld = rtplatform;
switch( et ){
@@ -248,7 +248,7 @@ public class FullIfElseTest extends AutomatedTestBase
writeInputMatrixWithMTD("C", C, true);
//run test cases
- runTest(true, false, null, -1);
+ runTest(null);
runRScript(true);
//compare output matrices
diff --git
a/src/test/java/org/apache/sysds/test/functions/ternary/QuantileWeightsTest.java
b/src/test/java/org/apache/sysds/test/functions/ternary/QuantileWeightsTest.java
index 811b874d8f..6aed3b9648 100644
---
a/src/test/java/org/apache/sysds/test/functions/ternary/QuantileWeightsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/ternary/QuantileWeightsTest.java
@@ -153,8 +153,8 @@ public class QuantileWeightsTest extends AutomatedTestBase
runQuantileTest(TEST_NAME3, -1, true, ExecType.SPARK);
}
- private void runQuantileTest( String TEST_NAME, double p, boolean
sparse, ExecType et)
- {
+ private void runQuantileTest( String TEST_NAME, double p, boolean
sparse, ExecType et){
+ setOutputBuffering(true);
//rtplatform for MR
ExecMode platformOld = rtplatform;
switch( et ){
@@ -186,7 +186,7 @@ public class QuantileWeightsTest extends AutomatedTestBase
double[][] W = getRandomMatrix(rows, 1, 1, 1, 1.0, 1);
writeInputMatrixWithMTD("W", W, true);
- runTest(true, false, null, -1);
+ runTest(null);
runRScript(true);
//compare matrices
diff --git
a/src/test/java/org/apache/sysds/test/functions/ternary/TableOutputTest.java
b/src/test/java/org/apache/sysds/test/functions/ternary/TableOutputTest.java
index d13f7e0d00..3b17c0385a 100644
--- a/src/test/java/org/apache/sysds/test/functions/ternary/TableOutputTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ternary/TableOutputTest.java
@@ -86,8 +86,8 @@ public class TableOutputTest extends AutomatedTestBase
runTableOutputTest(ExecType.CP, -5);
}
- private void runTableOutputTest( ExecType et, int delta)
- {
+ private void runTableOutputTest(ExecType et, int delta) {
+ setOutputBuffering(true);
//rtplatform for MR
ExecMode platformOld = rtplatform;
@@ -124,7 +124,7 @@ public class TableOutputTest extends AutomatedTestBase
double[][] B = TestUtils.floor(getRandomMatrix(rows, 1,
1, maxVal2, 1.0, -1));
writeInputMatrixWithMTD("B", B, true);
- runTest(true, false, null, -1);
+ runTest(null);
runRScript(true);
//compare matrices
diff --git
a/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
b/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
index a87908aba0..a7224d88b0 100644
---
a/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
@@ -181,10 +181,9 @@ public class TernaryAggregateTest extends AutomatedTestBase
runTernaryAggregateTest(TEST_NAME2, true, false, false,
ExecType.CP);
}
-
-
private void runTernaryAggregateTest(String testname, boolean sparse,
boolean vectors, boolean rewrites, ExecType et)
{
+ setOutputBuffering(true);
//rtplatform for MR
ExecMode platformOld = rtplatform;
switch( et ){
@@ -220,7 +219,7 @@ public class TernaryAggregateTest extends AutomatedTestBase
writeInputMatrixWithMTD("A", A, true);
//run test cases
- runTest(true, false, null, -1);
+ runTest(null);
runRScript(true);
//compare output matrices
diff --git
a/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpoint.dml
b/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpoint.dml
new file mode 100644
index 0000000000..f7eaa0e9fb
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpoint.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+while(FALSE){}
+print(sum(A))
+print(mean(A))
diff --git
a/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpointRemove.dml
b/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpointRemove.dml
new file mode 100644
index 0000000000..9a8605a145
--- /dev/null
+++
b/src/test/scripts/functions/rewrite/SPCheckpoint/RewriteSPCheckpointRemove.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+print(sum(A))