This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new bc0c19dcc8 [SYSTEMDS-3469] New operator ordering to maximize inter-op parallelism
bc0c19dcc8 is described below
commit bc0c19dcc8776e4fc3ac36bfb2dc6c5394541a6f
Author: Arnab Phani <[email protected]>
AuthorDate: Fri Nov 25 17:46:31 2022 +0100
[SYSTEMDS-3469] New operator ordering to maximize inter-op parallelism
This patch introduces a new heuristic-based operator linearization order
that aims to maximize inter-operator parallelism between Spark and local
operators. We first traverse the LOP DAGs to collect the roots of the Spark
operator chains and the number of Spark instructions in each subDAG. We
then place the Spark operator chains first, followed by the CP lanes.
Finally, we place the appropriate asynchronous operators to trigger the
Spark operator chains in parallel.

This change, along with the future-based execution of Spark actions and
a manual reuse of partitioned broadcast variables, improves lmDS by 2x.
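As a rough illustration of the heuristic, here is a minimal, self-contained
Java sketch with a toy Node class (names and structure are illustrative only,
not SystemDS APIs; the actual implementation lives in ILinearize below):

  import java.util.*;

  class Node {
      final long id; final boolean spark; final List<Node> inputs;
      Node(long id, boolean spark, Node... in) {
          this.id = id; this.spark = spark; this.inputs = Arrays.asList(in);
      }
  }

  public class MaxParallelizeSketch {
      // Step 1: memoized count of Spark operators in each subDAG
      static int countSpark(Node n, Map<Long,Integer> memo) {
          if (memo.containsKey(n.id)) return memo.get(n.id);
          int total = n.spark ? 1 : 0;
          for (Node in : n.inputs) total += countSpark(in, memo);
          memo.put(n.id, total);
          return total;
      }
      // Steps 2/3: depth-first emit; visit Spark-heavy inputs first (or last)
      static void emit(Node n, List<Node> out, Set<Long> visited,
              Map<Long,Integer> memo, boolean sparkFirst) {
          if (!visited.add(n.id)) return;
          List<Node> ins = new ArrayList<>(n.inputs);
          ins.sort(Comparator.comparingInt(x -> memo.get(x.id)));
          if (sparkFirst) Collections.reverse(ins);
          for (Node in : ins) emit(in, out, visited, memo, sparkFirst);
          out.add(n);
      }
      public static void main(String[] args) {
          Node a = new Node(1, true);        // Spark op
          Node b = new Node(2, true, a);     // root of a 2-op Spark chain
          Node c = new Node(3, false);       // independent CP op
          Node r = new Node(4, false, b, c); // CP consumer (DAG root)
          Map<Long,Integer> memo = new HashMap<>();
          countSpark(r, memo);
          List<Node> order = new ArrayList<>();
          emit(r, order, new HashSet<>(), memo, true);
          order.forEach(n -> System.out.print(n.id + " ")); // prints: 1 2 3 4
      }
  }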
Closes #1736
---
.../apache/sysds/conf/ConfigurationManager.java | 15 ++
.../java/org/apache/sysds/hops/OptimizerUtils.java | 6 +
src/main/java/org/apache/sysds/lops/Lop.java | 18 +-
.../java/org/apache/sysds/lops/compile/Dag.java | 13 +-
.../lops/compile/linearization/ILinearize.java | 209 ++++++++++++++++++++-
.../context/SparkExecutionContext.java | 10 +-
.../spark/AggregateUnarySPInstruction.java | 2 +-
.../instructions/spark/TsmmSPInstruction.java | 53 +++++-
.../test/functions/async/AsyncBroadcastTest.java | 2 +
...dcastTest.java => MaxParallelizeOrderTest.java} | 55 +++---
.../test/functions/async/PrefetchRDDTest.java | 5 +-
.../linearization/DagLinearizationTest.java | 2 +-
.../functions/async/MaxParallelizeOrder1.dml | 52 +++++
.../functions/async/MaxParallelizeOrder2.dml | 69 +++++++
...ect.xml => SystemDS-config-max-parallelize.xml} | 2 +-
15 files changed, 447 insertions(+), 66 deletions(-)
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 930d26a6d0..bb6172993a 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -23,6 +23,7 @@ import org.apache.hadoop.mapred.JobConf;
import org.apache.sysds.conf.CompilerConfig.ConfigType;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Compression.CompressConfig;
+import org.apache.sysds.lops.compile.linearization.ILinearize;
/**
* Singleton for accessing the parsed and merged system configuration.
@@ -237,11 +238,25 @@ public class ConfigurationManager
|| OptimizerUtils.ASYNC_PREFETCH_SPARK);
}
+ public static boolean isMaxParallelizeEnabled() {
+     return (getLinearizationOrder() == ILinearize.DagLinearization.MAX_PARALLELIZE
+         || OptimizerUtils.MAX_PARALLELIZE_ORDER);
+ }
+
public static boolean isBroadcastEnabled() {
    return (getDMLConfig().getBooleanValue(DMLConfig.ASYNC_SPARK_BROADCAST)
        || OptimizerUtils.ASYNC_BROADCAST_SPARK);
}
+ public static ILinearize.DagLinearization getLinearizationOrder() {
+     if (OptimizerUtils.MAX_PARALLELIZE_ORDER)
+         return ILinearize.DagLinearization.MAX_PARALLELIZE;
+     else
+         return ILinearize.DagLinearization
+             .valueOf(ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase());
+ }
+
///////////////////////////////////////
// Thread-local classes
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index ccee9c96df..d2e9670362 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -285,6 +285,12 @@ public class OptimizerUtils
public static boolean ASYNC_PREFETCH_SPARK = false;
public static boolean ASYNC_BROADCAST_SPARK = false;
+ /**
+  * Heuristic-based instruction ordering to maximize inter-operator parallelism.
+  * Place the Spark operator chains first and trigger them to execute in parallel.
+  */
+ public static boolean MAX_PARALLELIZE_ORDER = false;
+
//////////////////////
// Optimizer levels //
//////////////////////
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java
index 440669d13a..3f1cdfe8f6 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -188,6 +188,14 @@ public abstract class Lop
_visited = visited;
}
+ public void setVisited() {
+ setVisited(VisitStatus.DONE);
+ }
+
+ public boolean isVisited() {
+ return _visited == VisitStatus.DONE;
+ }
+
public boolean[] getReachable() {
return reachable;
@@ -297,6 +305,10 @@ public abstract class Lop
}
}
+ public void removeInput(Lop op) {
+ inputs.remove(op);
+ }
+
/**
* Method to add output to Lop
*
@@ -414,7 +426,11 @@ public abstract class Lop
public void setExecType(ExecType newExecType){
lps.setExecType(newExecType);
}
-
+
+ public boolean isExecSpark () {
+ return (lps.getExecType() == ExecType.SPARK);
+ }
+
public boolean getProducesIntermediateOutput() {
return lps.getProducesIntermediateOutput();
}
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index f87163eee3..2efbea8221 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -74,7 +74,6 @@ import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-
/**
*
* Class to maintain a DAG of lops and compile it into
@@ -193,17 +192,11 @@ public class Dag<N extends Lop>
}
List<Lop> node_v = ILinearize.linearize(nodes);
-
- // add Prefetch and broadcast lops, if necessary
-     List<Lop> node_pf = ConfigurationManager.isPrefetchEnabled() ? addPrefetchLop(node_v) : node_v;
-     List<Lop> node_bc = ConfigurationManager.isBroadcastEnabled() ? addBroadcastLop(node_pf) : node_pf;
- // TODO: Merge via a single traversal of the nodes
-
- prefetchFederated(node_bc);
+ prefetchFederated(node_v);
// do greedy grouping of operations
-     ArrayList<Instruction> inst = doPlainInstructionGen(sb, node_bc);
-
+ ArrayList<Instruction> inst = doPlainInstructionGen(sb, node_v);
+
// cleanup instruction (e.g., create packed rmvar instructions)
return cleanupInstructions(inst);
}
diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index f55271d530..d867a91f4a 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -21,8 +21,10 @@ package org.apache.sysds.lops.compile.linearization;
import java.util.AbstractMap;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
+import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -31,30 +33,50 @@ import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.CSVReBlock;
+import org.apache.sysds.lops.CentralMoment;
+import org.apache.sysds.lops.Checkpoint;
+import org.apache.sysds.lops.CoVariance;
+import org.apache.sysds.lops.GroupedAggregate;
+import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.MMTSJ;
+import org.apache.sysds.lops.MMZip;
+import org.apache.sysds.lops.MapMultChain;
+import org.apache.sysds.lops.ParameterizedBuiltin;
+import org.apache.sysds.lops.PickByCount;
+import org.apache.sysds.lops.ReBlock;
+import org.apache.sysds.lops.SpoofFused;
+import org.apache.sysds.lops.UAggOuterChain;
+import org.apache.sysds.lops.UnaryCP;
/**
 * An interface for the linearization algorithms that order the DAG nodes into a sequence of instructions to execute.
- *
+ *
* https://en.wikipedia.org/wiki/Linearizability#Linearization_points
*/
public interface ILinearize {
public static Log LOG = LogFactory.getLog(ILinearize.class.getName());
public enum DagLinearization {
- DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE
+ DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE, MAX_PARALLELIZE
}
public static List<Lop> linearize(List<Lop> v) {
try {
        DMLConfig dmlConfig = ConfigurationManager.getDMLConfig();
-         DagLinearization linearization = DagLinearization
-             .valueOf(dmlConfig.getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase());
+         DagLinearization linearization = ConfigurationManager.getLinearizationOrder();
switch(linearization) {
+ case MAX_PARALLELIZE:
+ return doMaxParallelizeSort(v);
case MIN_INTERMEDIATE:
return doMinIntermediateSort(v);
case BREADTH_FIRST:
@@ -65,7 +87,7 @@ public interface ILinearize {
}
}
catch(Exception e) {
-         LOG.warn("Invalid or failed DAG_LINEARIZATION, fallback to DEPTH_FIRST ordering");
+         LOG.warn("Invalid DAG_LINEARIZATION "+ConfigurationManager.getLinearizationOrder()+", fallback to DEPTH_FIRST ordering");
return depthFirst(v);
}
}
@@ -155,4 +177,181 @@ public interface ILinearize {
            sortRecursive(result, e.getKey().getInputs(), remaining);
}
}
+
+ // Place the Spark operation chains first (more expensive to less expensive),
+ // followed by asynchronously triggering operators and CP chains.
+ private static List<Lop> doMaxParallelizeSort(List<Lop> v)
+ {
+     List<Lop> final_v = null;
+     if (v.stream().anyMatch(ILinearize::isSparkAction)) {
+         // Step 1: Collect the Spark roots and #Spark instructions in each subDAG
+         Map<Long, Integer> sparkOpCount = new HashMap<>();
+         List<Lop> roots = v.stream().filter(l -> l.getOutputs().isEmpty()).collect(Collectors.toList());
+         List<Lop> sparkRoots = new ArrayList<>();
+         roots.forEach(r -> collectSparkRoots(r, sparkOpCount, sparkRoots));
+
+         // Step 2: Depth-first linearization. Place the Spark OPs first.
+         // Sort the Spark roots based on number of Spark operators descending
+         ArrayList<Lop> operatorList = new ArrayList<>();
+         Lop[] sortedSPRoots = sparkRoots.toArray(new Lop[0]);
+         Arrays.sort(sortedSPRoots, (l1, l2) -> sparkOpCount.get(l2.getID()) - sparkOpCount.get(l1.getID()));
+         Arrays.stream(sortedSPRoots).forEach(r -> depthFirst(r, operatorList, sparkOpCount, true));
+
+         // Step 3: Place the rest of the operators (CP). Sort the CP roots based on
+         // #Spark operators in ascending order, i.e. execute the independent CP chains first
+         roots.forEach(r -> depthFirst(r, operatorList, sparkOpCount, false));
+         roots.forEach(Lop::resetVisitStatus);
+         final_v = operatorList;
+     }
+     else
+         // Fall back to depth-first if none of the operators returns results back to local
+         final_v = depthFirst(v);
+
+     // Step 4: Add Prefetch and Broadcast lops if necessary
+     List<Lop> v_pf = ConfigurationManager.isPrefetchEnabled() ? addPrefetchLop(final_v) : final_v;
+     List<Lop> v_bc = ConfigurationManager.isBroadcastEnabled() ? addBroadcastLop(v_pf) : v_pf;
+     // TODO: Merge into a single traversal
+
+     return v_bc;
+ }
+
+ // Gather the Spark operators which return intermediates to local (actions/single_block)
+ // In addition, count the number of Spark OPs underneath every operator
+ private static int collectSparkRoots(Lop root, Map<Long, Integer> sparkOpCount, List<Lop> sparkRoots) {
+     if (sparkOpCount.containsKey(root.getID())) //visited before
+         return sparkOpCount.get(root.getID());
+
+     // Sum Spark operators in the child DAGs
+     int total = 0;
+     for (Lop input : root.getInputs())
+         total += collectSparkRoots(input, sparkOpCount, sparkRoots);
+
+     // Check if this node is Spark
+     total = root.isExecSpark() ? total + 1 : total;
+     sparkOpCount.put(root.getID(), total);
+
+     // Triggering point: Spark operator with all CP consumers
+     if (isSparkAction(root) && root.isAllOutputsCP())
+         sparkRoots.add(root);
+
+     return total;
+ }
+
+ // Place the operators in a depth-first manner, but order
+ // the DAGs based on number of Spark operators
+ private static void depthFirst(Lop root, ArrayList<Lop> opList, Map<Long, Integer> sparkOpCount, boolean sparkFirst) {
+     if (root.isVisited())
+         return;
+
+     if (root.getInputs().isEmpty()) { //leaf node
+         opList.add(root);
+         root.setVisited();
+         return;
+     }
+     // Sort the inputs based on number of Spark operators
+     Lop[] sortedInputs = root.getInputs().toArray(new Lop[0]);
+     if (sparkFirst) //to place the child DAG with more Spark OPs first
+         Arrays.sort(sortedInputs, (l1, l2) -> sparkOpCount.get(l2.getID()) - sparkOpCount.get(l1.getID()));
+     else //to place the child DAG with more CP OPs first
+         Arrays.sort(sortedInputs, Comparator.comparingInt(l -> sparkOpCount.get(l.getID())));
+
+     for (Lop input : sortedInputs)
+         depthFirst(input, opList, sparkOpCount, sparkFirst);
+
+     opList.add(root);
+     root.setVisited();
+ }
+
+ private static boolean isSparkAction(Lop lop) {
+     return lop.isExecSpark() && (lop.getAggType() == SparkAggType.SINGLE_BLOCK
+         || lop.getDataType() == DataType.SCALAR || lop instanceof MapMultChain
+         || lop instanceof PickByCount || lop instanceof MMZip || lop instanceof CentralMoment
+         || lop instanceof CoVariance || lop instanceof MMTSJ);
+ }
+
+ private static List<Lop> addPrefetchLop(List<Lop> nodes) {
+     List<Lop> nodesWithPrefetch = new ArrayList<>();
+
+     //Find the Spark nodes with all CP outputs
+     for (Lop l : nodes) {
+         nodesWithPrefetch.add(l);
+         if (isPrefetchNeeded(l)) {
+             //TODO: No prefetch if the parent is placed right after the spark OP
+             //or push the parent further to increase parallelism
+             List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
+             //Construct a Prefetch lop that takes this Spark node as an input
+             UnaryCP prefetch = new UnaryCP(l, OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
+             for (Lop outCP : oldOuts) {
+                 //Rewire l -> outCP to l -> Prefetch -> outCP
+                 prefetch.addOutput(outCP);
+                 outCP.replaceInput(l, prefetch);
+                 l.removeOutput(outCP);
+                 //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
+             }
+             //Place it immediately after the Spark lop in the node list
+             nodesWithPrefetch.add(prefetch);
+         }
+     }
+     return nodesWithPrefetch;
+ }
+
+ private static List<Lop> addBroadcastLop(List<Lop> nodes) {
+     List<Lop> nodesWithBroadcast = new ArrayList<>();
+
+     for (Lop l : nodes) {
+         nodesWithBroadcast.add(l);
+         if (isBroadcastNeeded(l)) {
+             List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
+             //Construct a Broadcast lop that takes this Spark node as an input
+             UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST, l.getDataType(), l.getValueType(), ExecType.CP);
+             //FIXME: Wire Broadcast only with the necessary outputs
+             for (Lop outCP : oldOuts) {
+                 //Rewire l -> outCP to l -> Broadcast -> outCP
+                 bc.addOutput(outCP);
+                 outCP.replaceInput(l, bc);
+                 l.removeOutput(outCP);
+                 //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
+             }
+             //Place it immediately after the Spark lop in the node list
+             nodesWithBroadcast.add(bc);
+         }
+     }
+     return nodesWithBroadcast;
+ }
+
+ private static boolean isPrefetchNeeded(Lop lop) {
+     // Run Prefetch for a Spark instruction if the instruction is a Transformation
+     // and the output is consumed by only CP instructions.
+     boolean transformOP = lop.getExecType() == ExecType.SPARK && lop.getAggType() != SparkAggType.SINGLE_BLOCK
+         // Always Action operations
+         && !(lop.getDataType() == DataType.SCALAR)
+         && !(lop instanceof MapMultChain) && !(lop instanceof PickByCount)
+         && !(lop instanceof MMZip) && !(lop instanceof CentralMoment)
+         && !(lop instanceof CoVariance)
+         // Not qualified for prefetching
+         && !(lop instanceof Checkpoint) && !(lop instanceof ReBlock)
+         && !(lop instanceof CSVReBlock)
+         // Cannot filter Transformation cases from Actions (FIXME)
+         && !(lop instanceof MMTSJ) && !(lop instanceof UAggOuterChain)
+         && !(lop instanceof ParameterizedBuiltin) && !(lop instanceof SpoofFused);
+
+     //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
+     boolean hasParameterizedOut = lop.getOutputs().stream()
+         .anyMatch(out -> ((out instanceof ParameterizedBuiltin)
+             || (out instanceof GroupedAggregate)
+             || (out instanceof GroupedAggregateM)));
+     //TODO: support non-matrix outputs
+     return transformOP && !hasParameterizedOut
+         && lop.isAllOutputsCP() && lop.getDataType() == DataType.MATRIX;
+ }
+
+ private static boolean isBroadcastNeeded(Lop lop) {
+     // Asynchronously broadcast a matrix if it is produced by a CP instruction,
+     // and at least one Spark parent needs to broadcast this intermediate (e.g. mapmm)
+     boolean isBc = lop.getOutputs().stream()
+         .anyMatch(out -> (out.getBroadcastInput() == lop));
+     //TODO: Early broadcast objects that are bigger than a single block
+     //return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
+     return isBc && lop.getDataType() == DataType.MATRIX;
+ }
}
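Note: like the existing orders, the new linearization can be selected via the
sysds.compile.linearization configuration property (value max_parallelize, as
exercised by the renamed test config at the end of this patch), or forced
programmatically by setting OptimizerUtils.MAX_PARALLELIZE_ORDER, which the
updated async tests below use.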
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 388eb462f9..df8f84c6b9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -672,11 +672,11 @@ public class SparkExecutionContext extends ExecutionContext
    //the broadcasts are created (other than in local mode) in order to avoid
    //unnecessary memory requirements during the lifetime of this broadcast handle.
- long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
PartitionedBroadcast<MatrixBlock> bret = null;
    synchronized (mo) { //synchronize with the async. broadcast thread
+ long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
//reuse existing broadcast handle
        if (mo.getBroadcastHandle() != null && mo.getBroadcastHandle().isPartitionedBroadcastValid()) {
            bret = mo.getBroadcastHandle().getPartitionedBroadcast();
@@ -719,10 +719,10 @@ public class SparkExecutionContext extends ExecutionContext
OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getDataCharacteristics()));
CacheableData.addBroadcastSize(mo.getBroadcastHandle().getSize());
-         if (DMLScript.STATISTICS) {
-             SparkStatistics.accBroadCastTime(System.nanoTime() - t0);
-             SparkStatistics.incBroadcastCount(1);
-         }
+     }
+     if (DMLScript.STATISTICS) {
+         SparkStatistics.accBroadCastTime(System.nanoTime() - t0);
+         SparkStatistics.incBroadcastCount(1);
    }
}
return bret;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 52bab3958f..89f385c54e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -109,7 +109,7 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction {
    //perform aggregation if necessary and put output into symbol table
if( _aggtype == SparkAggType.SINGLE_BLOCK )
{
-         if (ConfigurationManager.isPrefetchEnabled()) {
+         if (ConfigurationManager.isMaxParallelizeEnabled()) {
            //Trigger the chain of Spark operations and maintain a future to the result
//TODO: Make memory for the future matrix block
try {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
index 69db3787e5..17cef61158 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
@@ -23,6 +23,7 @@ package org.apache.sysds.runtime.instructions.spark;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
+import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -33,8 +34,13 @@ import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
import scala.Tuple2;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
public class TsmmSPInstruction extends UnarySPInstruction {
private MMTSJType _type = null;
@@ -61,15 +67,29 @@ public class TsmmSPInstruction extends UnarySPInstruction {
//get input
    JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
-
-     //execute tsmm instruction (always produce exactly one output block)
-     //(this formulation with values() requires --conf spark.driver.maxResultSize=0)
-     JavaRDD<MatrixBlock> tmp = in.map(new RDDTSMMFunction(_type));
-     MatrixBlock out = RDDAggregateUtils.sumStable(tmp);
-     //put output block into symbol table (no lineage because single block)
-     //this also includes implicit maintenance of matrix characteristics
-     sec.setMatrixOutput(output.getName(), out);
+     if (ConfigurationManager.isMaxParallelizeEnabled()) {
+         try {
+             if (CommonThreadPool.triggerRemoteOPsPool == null)
+                 CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
+             TsmmTask task = new TsmmTask(in, _type);
+             Future<MatrixBlock> future_out = CommonThreadPool.triggerRemoteOPsPool.submit(task);
+             sec.setMatrixOutput(output.getName(), future_out);
+         }
+         catch(Exception ex) {
+             throw new DMLRuntimeException(ex);
+         }
+     }
+     else {
+         //execute tsmm instruction (always produce exactly one output block)
+         //(this formulation with values() requires --conf spark.driver.maxResultSize=0)
+         JavaRDD<MatrixBlock> tmp = in.map(new RDDTSMMFunction(_type));
+         MatrixBlock out = RDDAggregateUtils.sumStable(tmp);
+
+         //put output block into symbol table (no lineage because single block)
+         //this also includes implicit maintenance of matrix characteristics
+         sec.setMatrixOutput(output.getName(), out);
+     }
}
    private static class RDDTSMMFunction implements Function<Tuple2<MatrixIndexes,MatrixBlock>, MatrixBlock>
@@ -90,5 +110,22 @@ public class TsmmSPInstruction extends UnarySPInstruction {
        return arg0._2().transposeSelfMatrixMultOperations(new MatrixBlock(), _type);
}
}
+
+ private static class TsmmTask implements Callable<MatrixBlock> {
+     JavaPairRDD<MatrixIndexes, MatrixBlock> _in;
+     MMTSJType _type;
+
+     TsmmTask(JavaPairRDD<MatrixIndexes, MatrixBlock> in, MMTSJType type) {
+         _in = in;
+         _type = type;
+     }
+     @Override
+     public MatrixBlock call() {
+         //execute tsmm instruction (always produce exactly one output block)
+         //(this formulation with values() requires --conf spark.driver.maxResultSize=0)
+         JavaRDD<MatrixBlock> tmp = _in.map(new RDDTSMMFunction(_type));
+         return RDDAggregateUtils.sumStable(tmp);
+     }
+ }
}
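Design note: submitting the tsmm action as a Callable to the shared cached
thread pool returns control to the driver immediately, so independent Spark
chains placed earlier by the new linearization can be triggered concurrently;
the Future stored in the symbol table is presumably only materialized when a
consumer first accesses the output block (an assumption based on the
Future-accepting setMatrixOutput overload used above).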
diff --git a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
index edfef8998d..c8b7fdd94f 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
@@ -88,9 +88,11 @@ public class AsyncBroadcastTest extends AutomatedTestBase {
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
        HashMap<MatrixValue.CellIndex, Double> R = readDMLScalarFromOutputDir("R");
+ OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
OptimizerUtils.ASYNC_BROADCAST_SPARK = true;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
OptimizerUtils.ASYNC_BROADCAST_SPARK = false;
+ OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
        HashMap<MatrixValue.CellIndex, Double> R_bc = readDMLScalarFromOutputDir("R");
//compare matrices
diff --git a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
similarity index 69%
copy from src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
copy to src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
index edfef8998d..be011925d2 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
@@ -31,32 +31,29 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.apache.sysds.utils.Statistics;
-import org.apache.sysds.utils.stats.SparkStatistics;
-import org.junit.Assert;
import org.junit.Test;
-public class AsyncBroadcastTest extends AutomatedTestBase {
-
+public class MaxParallelizeOrderTest extends AutomatedTestBase {
+
protected static final String TEST_DIR = "functions/async/";
- protected static final String TEST_NAME = "BroadcastVar";
+ protected static final String TEST_NAME = "MaxParallelizeOrder";
protected static final int TEST_VARIANTS = 2;
- protected static String TEST_CLASS_DIR = TEST_DIR + AsyncBroadcastTest.class.getSimpleName() + "/";
-
+ protected static String TEST_CLASS_DIR = TEST_DIR + MaxParallelizeOrderTest.class.getSimpleName() + "/";
+
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- for( int i=1; i<=TEST_VARIANTS; i++ )
+ for(int i=1; i<=TEST_VARIANTS; i++)
            addTestConfiguration(TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i));
}
-
+
@Test
- public void testAsyncBroadcast1() {
+ public void testlmds() {
runTest(TEST_NAME+"1");
}
@Test
- public void testAsyncBroadcast2() {
+ public void testl2svm() {
runTest(TEST_NAME+"2");
}
@@ -65,21 +62,19 @@ public class AsyncBroadcastTest extends AutomatedTestBase {
        boolean old_sum_product = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
        boolean old_trans_exec_type = OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE;
ExecMode oldPlatform = setExecMode(ExecMode.HYBRID);
-
+
long oldmem = InfrastructureAnalyzer.getLocalMaxMemory();
long mem = 1024*1024*8;
InfrastructureAnalyzer.setLocalMaxMemory(mem);
-
+
try {
- //OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
- //OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false;
- OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
getAndLoadTestConfiguration(testname);
fullDMLScriptName = getScript();
-
+
List<String> proArgs = new ArrayList<>();
-
- //proArgs.add("-explain");
+
+ proArgs.add("-explain");
+ //proArgs.add("recompile_runtime");
proArgs.add("-stats");
proArgs.add("-args");
proArgs.add(output("R"));
@@ -88,21 +83,17 @@ public class AsyncBroadcastTest extends AutomatedTestBase {
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
            HashMap<MatrixValue.CellIndex, Double> R = readDMLScalarFromOutputDir("R");
- OptimizerUtils.ASYNC_BROADCAST_SPARK = true;
+ OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
+ OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
- OptimizerUtils.ASYNC_BROADCAST_SPARK = false;
-             HashMap<MatrixValue.CellIndex, Double> R_bc = readDMLScalarFromOutputDir("R");
+             HashMap<MatrixValue.CellIndex, Double> R_mp = readDMLScalarFromOutputDir("R");
+ OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+ OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
//compare matrices
-             TestUtils.compareMatrices(R, R_bc, 1e-6, "Origin", "withBroadcast");
-
-             //assert called and successful early broadcast counts
-             long expected_numBC = 1;
-             long expected_successBC = 1;
-             long numBC = Statistics.getCPHeavyHitterCount("broadcast");
-             Assert.assertTrue("Violated Broadcast instruction count: "+numBC, numBC == expected_numBC);
-             long successBC = SparkStatistics.getAsyncBroadcastCount();
-             Assert.assertTrue("Violated successful Broadcast count: "+successBC, successBC == expected_successBC);
+             boolean matchVal = TestUtils.compareMatrices(R, R_mp, 1e-6, "Origin", "withPrefetch");
+             if (!matchVal)
+                 System.out.println("Value w/o Prefetch "+R+" w/ Prefetch "+R_mp);
} finally {
            OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = old_simplification;
            OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = old_sum_product;
diff --git a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index a2fa45c2f4..11af05f19d 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -32,7 +32,6 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;
-import org.apache.sysds.utils.stats.SparkStatistics;
import org.junit.Assert;
import org.junit.Test;
@@ -96,13 +95,15 @@ public class PrefetchRDDTest extends AutomatedTestBase {
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
        HashMap<MatrixValue.CellIndex, Double> R = readDMLScalarFromOutputDir("R");
+ OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+ OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
        HashMap<MatrixValue.CellIndex, Double> R_pf = readDMLScalarFromOutputDir("R");
//compare matrices
-         Boolean matchVal = TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "withPrefetch");
+         boolean matchVal = TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "withPrefetch");
        if (!matchVal)
            System.out.println("Value w/o Prefetch "+R+" w/ Prefetch "+R_pf);
        //assert Prefetch instructions and number of successes.
diff --git a/src/test/java/org/apache/sysds/test/functions/linearization/DagLinearizationTest.java b/src/test/java/org/apache/sysds/test/functions/linearization/DagLinearizationTest.java
index dc84c75e61..843faa36b9 100644
--- a/src/test/java/org/apache/sysds/test/functions/linearization/DagLinearizationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/linearization/DagLinearizationTest.java
@@ -39,7 +39,7 @@ public class DagLinearizationTest extends AutomatedTestBase {
    private final String testNames[] = {"matrixmult_dag_linearization",
        "csplineCG_dag_linearization", "linear_regression_dag_linearization"};
-     private final String testConfigs[] = {"breadth-first", "depth-first", "incorrect", "min-intermediate"};
+     private final String testConfigs[] = {"breadth-first", "depth-first", "min-intermediate", "max-parallelize"};
private final String testDir = "functions/linearization/";
diff --git a/src/test/scripts/functions/async/MaxParallelizeOrder1.dml b/src/test/scripts/functions/async/MaxParallelizeOrder1.dml
new file mode 100644
index 0000000000..218b09cb6e
--- /dev/null
+++ b/src/test/scripts/functions/async/MaxParallelizeOrder1.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y, Double lamda, Integer N) return (Matrix[double] beta)
+{
+ A = (t(X) %*% X) + diag(matrix(lamda, rows=N, cols=1));
+ b = t(X) %*% y;
+ beta = solve(A, b);
+}
+
+no_lamda = 10;
+
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+lim = 0.1;
+
+X = rand(rows=10000, cols=200, seed=42);
+y = rand(rows=10000, cols=1, seed=43);
+N = ncol(X);
+R = matrix(0, rows=N, cols=no_lamda+2);
+i = 1;
+
+while (lamda < lim)
+{
+ beta = SimlinRegDS(X, y, lamda, N);
+ #beta = lmDS(X=X, y=y, reg=lamda);
+ R[,i] = beta;
+ lamda = lamda + stp;
+ i = i + 1;
+}
+
+R = sum(R);
+write(R, $1, format="text");
+
diff --git a/src/test/scripts/functions/async/MaxParallelizeOrder2.dml b/src/test/scripts/functions/async/MaxParallelizeOrder2.dml
new file mode 100644
index 0000000000..81e9207105
--- /dev/null
+++ b/src/test/scripts/functions/async/MaxParallelizeOrder2.dml
@@ -0,0 +1,69 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B, Boolean icpt)
+return (Matrix[Double] loss) {
+ if (icpt)
+ X = cbind(X, matrix(1, nrow(X), 1));
+ loss = as.matrix(sum((y - X%*%B)^2));
+}
+
+M = 100000;
+N = 20;
+sp = 1.0;
+no_lamda = 1;
+
+X = rand(rows=M, cols=N, sparsity=sp, seed=42);
+y = rand(rows=M, cols=1, min=0, max=2, seed=42);
+y = ceil(y);
+
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+Rbeta = matrix(0, rows=ncol(X)+1, cols=no_lamda*2);
+Rloss = matrix(0, rows=no_lamda*2, cols=1);
+i = 1;
+
+
+for (l in 1:no_lamda)
+{
+ beta = l2svm(X=X, Y=y, intercept=FALSE, epsilon=1e-12,
+# lambda = lamda, maxIterations=10, verbose=FALSE);
+ reg = lamda, verbose=FALSE);
+ Rbeta[1:nrow(beta),i] = beta;
+ Rloss[i,] = l2norm(X, y, beta, FALSE);
+ i = i + 1;
+
+ beta = l2svm(X=X, Y=y, intercept=TRUE, epsilon=1e-12,
+# lambda = lamda, maxIterations=10, verbose=FALSE);
+ reg = lamda, verbose=FALSE);
+ Rbeta[1:nrow(beta),i] = beta;
+ Rloss[i,] = l2norm(X, y, beta, TRUE);
+ i = i + 1;
+
+ lamda = lamda + stp;
+}
+
+leastLoss = rowIndexMin(t(Rloss));
+bestModel = Rbeta[,as.scalar(leastLoss)];
+
+R = sum(bestModel);
+write(R, $1, format="text");
+
diff --git a/src/test/scripts/functions/linearization/SystemDS-config-incorrect.xml b/src/test/scripts/functions/linearization/SystemDS-config-max-parallelize.xml
similarity index 90%
rename from src/test/scripts/functions/linearization/SystemDS-config-incorrect.xml
rename to src/test/scripts/functions/linearization/SystemDS-config-max-parallelize.xml
index 62183138b3..25725397dc 100644
--- a/src/test/scripts/functions/linearization/SystemDS-config-incorrect.xml
+++ b/src/test/scripts/functions/linearization/SystemDS-config-max-parallelize.xml
@@ -18,5 +18,5 @@
-->
<root>
- <sysds.compile.linearization>something_incorrect</sysds.compile.linearization>
+ <sysds.compile.linearization>max_parallelize</sysds.compile.linearization>
</root>