This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 956a686c60 [SYSTEMDS-3473] Push down rmvars for asynchronous
instructions
956a686c60 is described below
commit 956a686c60afdcce95ac6f5398f181dc66e62ba2
Author: Arnab Phani <[email protected]>
AuthorDate: Thu Dec 1 22:52:24 2022 +0100
[SYSTEMDS-3473] Push down rmvars for asynchronous instructions
This patch repositions the rmvar instructions for the inputs to an
asynchronous instruction after the consumers of the asynchronous
instruction. This change allows keeping the inputs to an asynchronous
operator alive until get() is called for the future object.
Moreover, this patch adds more tests for the new operator ordering
and fixes minor bugs.
Closes #1745
---
src/main/java/org/apache/sysds/lops/Lop.java | 14 +++
.../java/org/apache/sysds/lops/compile/Dag.java | 26 +++-
.../lops/compile/linearization/ILinearize.java | 20 ++--
.../instructions/spark/CpmmSPInstruction.java | 133 +++++++++++++++++----
.../instructions/spark/MapmmSPInstruction.java | 50 +++++++-
.../functions/async/MaxParallelizeOrderTest.java | 15 ++-
.../functions/async/MaxParallelizeOrder3.dml | 36 ++++++
.../functions/async/MaxParallelizeOrder4.dml | 37 ++++++
8 files changed, 288 insertions(+), 43 deletions(-)
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java
b/src/main/java/org/apache/sysds/lops/Lop.java
index 3f1cdfe8f6..3064d2b39a 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -150,6 +150,12 @@ public abstract class Lop
protected OutputParameters outParams = null;
protected LopProperties lps = null;
+
+ /**
+ * Indicates if this lop is a candidate for asynchronous execution.
+ * Examples include spark unary aggregate, mapmm, prefetch
+ */
+ protected boolean _asynchronous = false;
/**
@@ -365,6 +371,14 @@ public abstract class Lop
return consumerCount;
}
+ public void setAsynchronous(boolean isAsync) {
+ _asynchronous = isAsync;
+ }
+
+ public boolean isAsynchronousOp() {
+ return _asynchronous;
+ }
+
/**
* Method to have Lops print their state. This is for debugging
purposes.
*/
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 77325fb297..ade809aea6 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -422,16 +422,27 @@ public class Dag<N extends Lop>
* @param delteInst list of instructions
*/
private static void processConsumersForInputs(Lop node,
List<Instruction> inst, List<Instruction> delteInst) {
+ // The asynchronous instructions execute lazily. The inputs to
an asynchronous instruction
+ // must live till the outputs of the async. instruction are
consumed (i.e. future.get is called)
+ if (node.isAsynchronousOp())
+ return;
+
// reduce the consumer count for all input lops
- // if the count becomes zero, then then variable associated w/
input can be removed
+ // if the count becomes zero, then variable associated w/ input
can be removed
for(Lop in : node.getInputs() )
processConsumers(in, inst, delteInst, null);
}
private static void processConsumers(Lop node, List<Instruction> inst,
List<Instruction> deleteInst, Lop locationInfo) {
// reduce the consumer count for all input lops
- // if the count becomes zero, then then variable associated w/
input can be removed
+ // if the count becomes zero, then variable associated w/ input
can be removed
+
if ( node.removeConsumer() == 0 ) {
+ // The inputs to the asynchronous instruction can be safely
removed at this point as
+ // the outputs of the asynchronous instruction are
consumed.
+ if (node.isAsynchronousOp())
+ processConsumerIfAsync(node, inst, deleteInst);
+
if ( node.isDataExecLocation() &&
((Data)node).isLiteral() ) {
return;
}
@@ -450,6 +461,17 @@ public class Dag<N extends Lop>
excludeRemoveInstruction(label, deleteInst);
}
}
+
+ // Generate rmvar instructions for the inputs of an asynchronous
instruction.
+ private static void processConsumerIfAsync(Lop node, List<Instruction>
inst, List<Instruction> deleteInst) {
+ if (!node.isAsynchronousOp())
+ return;
+
+ // Temporarily disable the _asynchronous flag to generate
rmvars for the inputs
+ node.setAsynchronous(false);
+ processConsumersForInputs(node, inst, deleteInst);
+ node.setAsynchronous(true);
+ }
/**
* Method to generate instructions that are executed in Control
Program. At
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index e78530b33b..7eee970e2b 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -183,7 +183,7 @@ public interface ILinearize {
private static List<Lop> doMaxParallelizeSort(List<Lop> v)
{
List<Lop> final_v = null;
- if (v.stream().anyMatch(ILinearize::isSparkAction)) {
+ if (v.stream().anyMatch(ILinearize::isSparkTriggeringOp)) {
// Step 1: Collect the Spark roots and #Spark
instructions in each subDAG
Map<Long, Integer> sparkOpCount = new HashMap<>();
List<Lop> roots = v.stream().filter(l ->
l.getOutputs().isEmpty()).collect(Collectors.toList());
@@ -221,7 +221,7 @@ public interface ILinearize {
if (sparkOpCount.containsKey(root.getID())) //visited before
return sparkOpCount.get(root.getID());
- // Sum Spark operators in the child DAGs
+ // Aggregate #Spark operators in the child DAGs
int total = 0;
for (Lop input : root.getInputs())
total += collectSparkRoots(input, sparkOpCount,
sparkRoots);
@@ -230,9 +230,11 @@ public interface ILinearize {
total = root.isExecSpark() ? total + 1 : total;
sparkOpCount.put(root.getID(), total);
- // Triggering point: Spark operator with all CP consumers
- if (isSparkAction(root) && root.isAllOutputsCP())
+ // Triggering point: Spark action/operator with all CP consumers
+ if (isSparkTriggeringOp(root)) {
sparkRoots.add(root);
+ root.setAsynchronous(true); //candidate for async.
execution
+ }
return total;
}
@@ -262,11 +264,11 @@ public interface ILinearize {
root.setVisited();
}
- private static boolean isSparkAction(Lop lop) {
+ private static boolean isSparkTriggeringOp(Lop lop) {
return lop.isExecSpark() && (lop.getAggType() ==
SparkAggType.SINGLE_BLOCK
|| lop.getDataType() == DataType.SCALAR || lop
instanceof MapMultChain
|| lop instanceof PickByCount || lop instanceof MMZip
|| lop instanceof CentralMoment
- || lop instanceof CoVariance || lop instanceof MMTSJ);
+ || lop instanceof CoVariance || lop instanceof MMTSJ ||
lop.isAllOutputsCP());
}
private static List<Lop> addPrefetchLop(List<Lop> nodes) {
@@ -276,11 +278,12 @@ public interface ILinearize {
for (Lop l : nodes) {
nodesWithPrefetch.add(l);
if (isPrefetchNeeded(l)) {
- //TODO: No prefetch if the parent is placed
right after the spark OP
- //or push the parent further to increase
parallelism
List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
//Construct a Prefetch lop that takes this
Spark node as a input
UnaryCP prefetch = new UnaryCP(l,
OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
+ prefetch.setAsynchronous(true);
+ //Reset asynchronous flag for the input if
already set (e.g. mapmm -> prefetch)
+ l.setAsynchronous(false);
for (Lop outCP : oldOuts) {
//Rewire l -> outCP to l -> Prefetch ->
outCP
prefetch.addOutput(outCP);
@@ -304,6 +307,7 @@ public interface ILinearize {
List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
//Construct a Broadcast lop that takes this
Spark node as an input
UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST,
l.getDataType(), l.getValueType(), ExecType.CP);
+ bc.setAsynchronous(true);
//FIXME: Wire Broadcast only with the necessary
outputs
for (Lop outCP : oldOuts) {
//Rewire l -> outCP to l -> Broadcast
-> outCP
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 0cbd4acfe0..653596806d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -23,6 +23,7 @@ import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
+import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -43,8 +44,13 @@ import
org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.util.CommonThreadPool;
import scala.Tuple2;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
/**
* Cpmm: cross-product matrix multiplication operation (distributed matrix
multiply
* by join over common dimension and subsequent aggregation of partial
results).
@@ -96,19 +102,31 @@ public class CpmmSPInstruction extends
AggregateBinarySPInstruction {
}
if( SparkUtils.isHashPartitioned(in1) //ZIPMM-like CPMM
- && mc1.getNumRowBlocks()==1 && mc2.getCols()==1 ) {
- //note: if the major input is hash-partitioned and it's
a matrix-vector
- //multiply, avoid the index mapping to preserve the
partitioning similar
- //to a ZIPMM but with different transpose
characteristics
- JavaRDD<MatrixBlock> out = in1
- .join(in2.mapToPair(new ReorgMapFunction("r'")))
- .values().map(new Cpmm2MultiplyFunction())
- .filter(new FilterNonEmptyBlocksFunction2());
- MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
-
- //put output block into symbol table (no lineage
because single block)
- //this also includes implicit maintenance of matrix
characteristics
- sec.setMatrixOutput(output.getName(), out2);
+ && mc1.getNumRowBlocks()==1 && mc2.getCols()==1 )
+ //note: if the major input is hash-partitioned and it's a
matrix-vector
+ //multiply, avoid the index mapping to preserve the
partitioning similar
+ //to a ZIPMM but with different transpose characteristics
+ {
+ if (ConfigurationManager.isMaxPrallelizeEnabled()) {
+ try {
+
if(CommonThreadPool.triggerRemoteOPsPool == null)
+
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
+ CpmmMatrixVectorTask task = new
CpmmMatrixVectorTask(in1, in2);
+ Future<MatrixBlock> future_out =
CommonThreadPool.triggerRemoteOPsPool.submit(task);
+ sec.setMatrixOutput(output.getName(),
future_out);
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+ else {
+ JavaRDD<MatrixBlock> out =
in1.join(in2.mapToPair(new ReorgMapFunction("r'"))).values().map(new
Cpmm2MultiplyFunction()).filter(new FilterNonEmptyBlocksFunction2());
+ MatrixBlock out2 =
RDDAggregateUtils.sumStable(out);
+
+ //put output block into symbol table (no
lineage because single block)
+ //this also includes implicit maintenance of
matrix characteristics
+ sec.setMatrixOutput(output.getName(), out2);
+ }
}
else //GENERAL CPMM
{
@@ -119,21 +137,39 @@ public class CpmmSPInstruction extends
AggregateBinarySPInstruction {
//process core cpmm matrix multiply
JavaPairRDD<Long, IndexedMatrixValue> tmp1 =
in1.mapToPair(new CpmmIndexFunction(true));
JavaPairRDD<Long, IndexedMatrixValue> tmp2 =
in2.mapToPair(new CpmmIndexFunction(false));
- JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1
- .join(tmp2, numPartJoin) // join
over common dimension
- .mapToPair(new CpmmMultiplyFunction()); //
compute block multiplications
-
+
//process cpmm aggregation and handle outputs
- if( _aggtype == SparkAggType.SINGLE_BLOCK ) {
- //prune empty blocks and aggregate all results
- out = out.filter(new
FilterNonEmptyBlocksFunction());
- MatrixBlock out2 =
RDDAggregateUtils.sumStable(out);
-
- //put output block into symbol table (no
lineage because single block)
- //this also includes implicit maintenance of
matrix characteristics
- sec.setMatrixOutput(output.getName(), out2);
+ if( _aggtype == SparkAggType.SINGLE_BLOCK )
+ {
+ if
(ConfigurationManager.isMaxPrallelizeEnabled()) {
+ try {
+
if(CommonThreadPool.triggerRemoteOPsPool == null)
+
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
+ CpmmMatrixMatrixTask task = new
CpmmMatrixMatrixTask(in1, in2, numPartJoin);
+ Future<MatrixBlock> future_out
= CommonThreadPool.triggerRemoteOPsPool.submit(task);
+
sec.setMatrixOutput(output.getName(), future_out);
+ }
+ catch(Exception ex) { throw new
DMLRuntimeException(ex); }
+ }
+ else {
+ JavaPairRDD<MatrixIndexes, MatrixBlock>
out = tmp1
+ .join(tmp2, numPartJoin)
// join over common dimension
+ .mapToPair(new
CpmmMultiplyFunction()); // compute block multiplications
+ //prune empty blocks and aggregate all
results
+ out = out.filter(new
FilterNonEmptyBlocksFunction());
+ MatrixBlock out2 =
RDDAggregateUtils.sumStable(out);
+
+ //put output block into symbol table
(no lineage because single block)
+ //this also includes implicit
maintenance of matrix characteristics
+ sec.setMatrixOutput(output.getName(),
out2);
+ }
+
}
- else { //DEFAULT: MULTI_BLOCK
+ else
+ { //DEFAULT: MULTI_BLOCK
+ JavaPairRDD<MatrixIndexes,MatrixBlock> out =
tmp1
+ .join(tmp2, numPartJoin)
// join over common dimension
+ .mapToPair(new CpmmMultiplyFunction());
// compute block multiplications
if( !_outputEmptyBlocks ||
mc1.isNoEmptyBlocks() || mc2.isNoEmptyBlocks() )
out = out.filter(new
FilterNonEmptyBlocksFunction());
out = RDDAggregateUtils.sumByKeyStable(out,
false);
@@ -234,4 +270,49 @@ public class CpmmSPInstruction extends
AggregateBinarySPInstruction {
return OperationsOnMatrixValues.matMult(in1, in2, new
MatrixBlock(), _op);
}
}
+
+ private static class CpmmMatrixVectorTask implements
Callable<MatrixBlock>
+ {
+ JavaPairRDD<MatrixIndexes, MatrixBlock> _in1;
+ JavaPairRDD<MatrixIndexes, MatrixBlock> _in2;
+
+ CpmmMatrixVectorTask(JavaPairRDD<MatrixIndexes, MatrixBlock>
in1, JavaPairRDD<MatrixIndexes, MatrixBlock> in2) {
+ _in1 = in1;
+ _in2 = in2;
+ }
+ @Override
+ public MatrixBlock call() {
+ JavaRDD<MatrixBlock> out = _in1
+ .join(_in2.mapToPair(new
ReorgMapFunction("r'")))
+ .values().map(new Cpmm2MultiplyFunction())
+ .filter(new FilterNonEmptyBlocksFunction2());
+ return RDDAggregateUtils.sumStable(out);
+ }
+ }
+
+ private static class CpmmMatrixMatrixTask implements
Callable<MatrixBlock>
+ {
+ JavaPairRDD<MatrixIndexes, MatrixBlock> _in1;
+ JavaPairRDD<MatrixIndexes, MatrixBlock> _in2;
+ int _numPartJoin;
+
+ CpmmMatrixMatrixTask(JavaPairRDD<MatrixIndexes, MatrixBlock>
in1, JavaPairRDD<MatrixIndexes, MatrixBlock> in2, int nPartJoin) {
+ _in1 = in1;
+ _in2 = in2;
+ _numPartJoin = nPartJoin;
+ }
+ @Override
+ public MatrixBlock call() {
+ //process core cpmm matrix multiply
+ JavaPairRDD<Long, IndexedMatrixValue> tmp1 =
_in1.mapToPair(new CpmmIndexFunction(true));
+ JavaPairRDD<Long, IndexedMatrixValue> tmp2 =
_in2.mapToPair(new CpmmIndexFunction(false));
+ JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1
+ .join(tmp2, _numPartJoin) //
join over common dimension
+ .mapToPair(new CpmmMultiplyFunction()); //
compute block multiplications
+
+ //prune empty blocks and aggregate all results
+ out = out.filter(new FilterNonEmptyBlocksFunction());
+ return RDDAggregateUtils.sumStable(out);
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index 3a1a6c27d9..29f28b604e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -21,6 +21,9 @@ package org.apache.sysds.runtime.instructions.spark;
import java.util.Iterator;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
import java.util.stream.IntStream;
import org.apache.commons.logging.Log;
@@ -30,6 +33,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
+import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.MapMult;
@@ -54,6 +58,7 @@ import
org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.util.CommonThreadPool;
import scala.Tuple2;
public class MapmmSPInstruction extends AggregateBinarySPInstruction {
@@ -135,12 +140,24 @@ public class MapmmSPInstruction extends
AggregateBinarySPInstruction {
//execute mapmm and aggregation if necessary and put output
into symbol table
if( _aggtype == SparkAggType.SINGLE_BLOCK )
{
- JavaRDD<MatrixBlock> out = in1.map(new
RDDMapMMFunction2(type, in2));
- MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
-
- //put output block into symbol table (no lineage
because single block)
- //this also includes implicit maintenance of matrix
characteristics
- sec.setMatrixOutput(output.getName(), out2);
+ if (ConfigurationManager.isMaxPrallelizeEnabled()) {
+ try {
+
if(CommonThreadPool.triggerRemoteOPsPool == null)
+
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
+ RDDMapmmTask task = new
RDDMapmmTask(in1, in2, type);
+ Future<MatrixBlock> future_out =
CommonThreadPool.triggerRemoteOPsPool.submit(task);
+ sec.setMatrixOutput(output.getName(),
future_out);
+ }
+ catch(Exception ex) { throw new
DMLRuntimeException(ex); }
+ }
+ else {
+ JavaRDD<MatrixBlock> out = in1.map(new
RDDMapMMFunction2(type, in2));
+ MatrixBlock out2 =
RDDAggregateUtils.sumStable(out);
+
+ //put output block into symbol table (no
lineage because single block)
+ //this also includes implicit maintenance of
matrix characteristics
+ sec.setMatrixOutput(output.getName(), out2);
+ }
}
else //MULTI_BLOCK or NONE
{
@@ -443,4 +460,25 @@ public class MapmmSPInstruction extends
AggregateBinarySPInstruction {
}
}
}
+
+ private static class RDDMapmmTask implements Callable<MatrixBlock>
+ {
+ JavaPairRDD<MatrixIndexes, MatrixBlock> _in1;
+ PartitionedBroadcast<MatrixBlock> _in2;
+ MapMult.CacheType _type;
+
+ RDDMapmmTask(JavaPairRDD<MatrixIndexes, MatrixBlock> in1,
PartitionedBroadcast<MatrixBlock> in2, MapMult.CacheType type) {
+ _in1 = in1;
+ _in2 = in2;
+ _type = type;
+ }
+
+ @Override
+ public MatrixBlock call() {
+ //execute mapmm and aggregation if necessary and put
output into symbol table
+ JavaRDD<MatrixBlock> out = _in1.map(new
RDDMapMMFunction2(_type, _in2));
+ MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
+ return out2;
+ }
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
index be011925d2..ee89824c64 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
@@ -37,7 +37,7 @@ public class MaxParallelizeOrderTest extends
AutomatedTestBase {
protected static final String TEST_DIR = "functions/async/";
protected static final String TEST_NAME = "MaxParallelizeOrder";
- protected static final int TEST_VARIANTS = 2;
+ protected static final int TEST_VARIANTS = 4;
protected static String TEST_CLASS_DIR = TEST_DIR +
MaxParallelizeOrderTest.class.getSimpleName() + "/";
@Override
@@ -57,6 +57,16 @@ public class MaxParallelizeOrderTest extends
AutomatedTestBase {
runTest(TEST_NAME+"2");
}
+ @Test
+ public void testSparkAction() {
+ runTest(TEST_NAME+"3");
+ }
+
+ @Test
+ public void testSparkTransformations() {
+ runTest(TEST_NAME+"4");
+ }
+
public void runTest(String testname) {
boolean old_simplification =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
boolean old_sum_product =
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -85,10 +95,13 @@ public class MaxParallelizeOrderTest extends
AutomatedTestBase {
OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
+ if (testname.equalsIgnoreCase(TEST_NAME+"4"))
+ OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE
= false;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> R_mp =
readDMLScalarFromOutputDir("R");
OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
+ OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
//compare matrices
boolean matchVal = TestUtils.compareMatrices(R, R_mp,
1e-6, "Origin", "withPrefetch");
diff --git a/src/test/scripts/functions/async/MaxParallelizeOrder3.dml
b/src/test/scripts/functions/async/MaxParallelizeOrder3.dml
new file mode 100644
index 0000000000..c26dee67c5
--- /dev/null
+++ b/src/test/scripts/functions/async/MaxParallelizeOrder3.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+v = rand(rows=200, cols=1, seed=42); #cp_rand
+
+# CP instructions
+v = ((v + v) * 1 - v) / (1+1);
+v = ((v + v) * 2 - v) / (2+1);
+
+# Spark transformation operations
+sp1 = X + ceil(X);
+sp2 = sp1 %*% v; #output fits in local
+
+# CP binary triggers the DAG of SP operations
+# if transitive spark exec type is off
+cp = sp2 + sum(v);
+R = sum(cp);
+write(R, $1, format="text");
diff --git a/src/test/scripts/functions/async/MaxParallelizeOrder4.dml
b/src/test/scripts/functions/async/MaxParallelizeOrder4.dml
new file mode 100644
index 0000000000..2a7f7001ac
--- /dev/null
+++ b/src/test/scripts/functions/async/MaxParallelizeOrder4.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+v = rand(rows=200, cols=1, seed=42); #cp_rand
+v2 = rand(rows=200, cols=1, seed=43); #cp_rand
+
+# CP instructions
+v = ((v + v) * 1 - v) / (1+1);
+v = ((v + v) * 2 - v) / (2+1);
+
+# Spark transformation operations
+sp1 = X + ceil(X);
+sp2 = sp1 %*% v2; #output fits in local
+
+# CP binary triggers the DAG of SP operations
+# if transitive spark exec type is off
+cp = sp2 + sum(v);
+R = sum(cp);
+write(R, $1, format="text");