This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 5142a7e739 [SYSTEMDS-3497] Refactor to add LOP rewrite step in
compilation
5142a7e739 is described below
commit 5142a7e7390ff816af99c593cf31a0402ae931ee
Author: Arnab Phani <[email protected]>
AuthorDate: Thu Feb 9 14:33:11 2023 +0100
[SYSTEMDS-3497] Refactor to add LOP rewrite step in compilation
This patch adds a new step in the compilation (and recompilation)
steps to rewrite Lop DAGs (single and multi-statement block).
Current rewrite passes include adding prefetch, broadcast and
checkpoint nodes. This refactoring allows us to easily add new
rewrite rules and to separate the Lop rewrites from operator
ordering.
Closes #1783
---
src/main/java/org/apache/sysds/api/DMLScript.java | 5 +-
.../apache/sysds/hops/recompile/Recompiler.java | 21 +-
src/main/java/org/apache/sysds/lops/Lop.java | 4 +
.../java/org/apache/sysds/lops/LopProperties.java | 1 +
.../apache/sysds/lops/OperatorOrderingUtils.java | 125 +++++++++++
src/main/java/org/apache/sysds/lops/UnaryCP.java | 2 +-
.../lops/compile/linearization/ILinearize.java | 231 ++-------------------
.../apache/sysds/lops/rewrite/LopRewriteRule.java | 30 +++
.../org/apache/sysds/lops/rewrite/LopRewriter.java | 134 ++++++++++++
.../sysds/lops/rewrite/RewriteAddBroadcastLop.java | 83 ++++++++
.../sysds/lops/rewrite/RewriteAddChkpointLop.java | 117 +++++++++++
.../sysds/lops/rewrite/RewriteAddPrefetchLop.java | 118 +++++++++++
.../apache/sysds/lops/rewrite/RewriteFixIDs.java | 67 ++++++
.../org/apache/sysds/parser/DMLTranslator.java | 8 +-
.../test/functions/async/AsyncBroadcastTest.java | 4 -
.../functions/async/CheckpointSharedOpsTest.java | 4 +-
.../test/functions/async/PrefetchRDDTest.java | 1 -
17 files changed, 725 insertions(+), 230 deletions(-)
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java
b/src/main/java/org/apache/sysds/api/DMLScript.java
index 0c3716dd35..ad386e29e8 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -455,8 +455,11 @@ public class DMLScript
//Step 6: construct lops (incl exec type and op selection)
dmlt.constructLops(prog);
+
+ //Step 7: rewrite LOP DAGs (incl adding new LOPs s.a. prefetch,
broadcast)
+ dmlt.rewriteLopDAG(prog);
- //Step 7: generate runtime program, incl codegen
+ //Step 8: generate runtime program, incl codegen
Program rtprog = dmlt.getRuntimeProgram(prog,
ConfigurationManager.getDMLConfig());
//Step 9: prepare statistics [and optional explain output]
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index f3fb0fcbc6..392d303ca4 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -59,6 +59,7 @@ import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriter;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.lops.rewrite.LopRewriter;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.ForStatementBlock;
@@ -130,6 +131,10 @@ public class Recompiler {
private static ThreadLocal<ProgramRewriter> _rewriter = new
ThreadLocal<ProgramRewriter>() {
@Override protected ProgramRewriter initialValue() { return new
ProgramRewriter(false, true); }
};
+
+ private static ThreadLocal<LopRewriter> _lopRewriter = new
ThreadLocal<LopRewriter>() {
+ @Override protected LopRewriter initialValue() {return new
LopRewriter();}
+ };
public enum ResetType {
RESET,
@@ -145,6 +150,7 @@ public class Recompiler {
*/
public static void reinitRecompiler() {
_rewriter.set(new ProgramRewriter(false, true));
+ _lopRewriter.set(new LopRewriter());
}
public static ArrayList<Instruction> recompileHopsDag( StatementBlock
sb, ArrayList<Hop> hops,
@@ -305,6 +311,7 @@ public class Recompiler {
boolean codegen = ConfigurationManager.isCodegenEnabled()
&& !(forceEt && et == null ) //not on reset
&& SpoofCompiler.RECOMPILE_CODEGEN;
+ boolean rewrittenHops = false;
// prepare hops dag for recompile
if( !inplace ){
@@ -352,6 +359,7 @@ public class Recompiler {
Hop.resetVisitStatus(hops);
for( Hop hopRoot : hops )
rUpdateStatistics( hopRoot,
ec.getVariables() );
+ rewrittenHops = true;
}
// refresh memory estimates (based on updated stats,
@@ -382,11 +390,18 @@ public class Recompiler {
rSetMaxParallelism(hops, maxK);
// construct lops
- Dag<Lop> dag = new Dag<>();
+ ArrayList<Lop> lops = new ArrayList<>();
for( Hop hopRoot : hops ){
- Lop lops = hopRoot.constructLops();
- lops.addToDag(dag);
+ lops.add(hopRoot.constructLops());
}
+
+ // dynamic lop rewrites for the updated hop DAGs
+ if (rewrittenHops)
+ _lopRewriter.get().rewriteLopDAG(lops);
+
+ Dag<Lop> dag = new Dag<>();
+ for (Lop l : lops)
+ l.addToDag(dag);
// generate runtime instructions (incl piggybacking)
ArrayList<Instruction> newInst = dag
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java
b/src/main/java/org/apache/sysds/lops/Lop.java
index ecc9e7f893..b768ded9ad 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -401,6 +401,10 @@ public abstract class Lop
public long getID() {
return lps.getID();
}
+
+ public void setNewID() {
+ lps.setNewID();
+ }
public int getLevel() {
return lps.getLevel();
diff --git a/src/main/java/org/apache/sysds/lops/LopProperties.java
b/src/main/java/org/apache/sysds/lops/LopProperties.java
index ed788c79fa..9fce0b6fb0 100644
--- a/src/main/java/org/apache/sysds/lops/LopProperties.java
+++ b/src/main/java/org/apache/sysds/lops/LopProperties.java
@@ -54,6 +54,7 @@ public class LopProperties
}
public long getID() { return ID; }
+ public void setNewID() { ID = UniqueLopID.getNextID(); }
public int getLevel() { return level; }
public void setLevel( int l ) { level = l; }
diff --git a/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
b/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
new file mode 100644
index 0000000000..35926961f2
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+public class OperatorOrderingUtils
+{
+ // Return a list representation of all the lops in a SB
+ public static ArrayList<Lop> getLopList(StatementBlock sb) {
+ ArrayList<Lop> lops = null;
+ if (sb.getLops() != null && !sb.getLops().isEmpty()) {
+ lops = new ArrayList<>();
+ for (Lop root : sb.getLops())
+ addToLopList(lops, root);
+ }
+ return lops;
+ }
+
+ // Determine if a lop is root of a DAG
+ public static boolean isLopRoot(Lop lop) {
+ if (lop.getOutputs().isEmpty())
+ return true;
+ //TODO: Handle internal builtins (e.g. eigen)
+ if (lop instanceof FunctionCallCP &&
+ ((FunctionCallCP)
lop).getFnamespace().equalsIgnoreCase(DMLProgram.INTERNAL_NAMESPACE)) {
+ return true;
+ }
+ return false;
+ }
+
+ // Gather the Spark operators which return intermediates to local
(actions/single_block)
+ // In addition count the number of Spark OPs underneath every Operator
+ public static int collectSparkRoots(Lop root, Map<Long, Integer>
sparkOpCount, List<Lop> sparkRoots) {
+ if (sparkOpCount.containsKey(root.getID())) //visited before
+ return sparkOpCount.get(root.getID());
+
+ // Aggregate #Spark operators in the child DAGs
+ int total = 0;
+ for (Lop input : root.getInputs())
+ total += collectSparkRoots(input, sparkOpCount,
sparkRoots);
+
+ // Check if this node is Spark
+ total = root.isExecSpark() ? total + 1 : total;
+ sparkOpCount.put(root.getID(), total);
+
+ // Triggering point: Spark action/operator with all CP consumers
+ if (isSparkTriggeringOp(root)) {
+ sparkRoots.add(root);
+ root.setAsynchronous(true); //candidate for async.
execution
+ }
+
+ return total;
+ }
+
+ // Dictionary of Spark operators which are expensive enough to be
+ // benefited from persisting if shared among jobs.
+ public static boolean isPersistableSparkOp(Lop lop) {
+ return lop.isExecSpark() && (lop instanceof MapMult
+ || lop instanceof MMCJ || lop instanceof MMRJ
+ || lop instanceof MMZip);
+ }
+
+ private static boolean isSparkTriggeringOp(Lop lop) {
+ boolean rightSpLop = lop.isExecSpark() && (lop.getAggType() ==
AggBinaryOp.SparkAggType.SINGLE_BLOCK
+ || lop.getDataType() == Types.DataType.SCALAR || lop
instanceof MapMultChain
+ || lop instanceof PickByCount || lop instanceof MMZip
|| lop instanceof CentralMoment
+ || lop instanceof CoVariance || lop instanceof MMTSJ ||
lop.isAllOutputsCP());
+ boolean isPrefetched = lop.getOutputs().size() == 1
+ && lop.getOutputs().get(0) instanceof UnaryCP
+ && ((UnaryCP)
lop.getOutputs().get(0)).getOpCode().equalsIgnoreCase("prefetch");
+ boolean col2Bc = isCollectForBroadcast(lop);
+ boolean prefetch = (lop instanceof UnaryCP) &&
+ ((UnaryCP)
lop).getOpCode().equalsIgnoreCase("prefetch");
+ return (rightSpLop || col2Bc || prefetch) && !isPrefetched;
+ }
+
+ // Determine if the result of this operator is collected to
+ // broadcast for the next operator (e.g. mapmm --> map+)
+ public static boolean isCollectForBroadcast(Lop lop) {
+ boolean isSparkOp = lop.isExecSpark();
+ boolean isBc = lop.getOutputs().stream()
+ .allMatch(out -> (out.getBroadcastInput() == lop));
+ //TODO: Handle Lops with mixed Spark (broadcast) CP consumers
+ return isSparkOp && isBc && (lop.getDataType() ==
Types.DataType.MATRIX);
+ }
+
+ private static boolean addNode(ArrayList<Lop> lops, Lop node) {
+ if (lops.contains(node))
+ return false;
+ lops.add(node);
+ return true;
+ }
+
+ private static void addToLopList(ArrayList<Lop> lops, Lop lop) {
+ if (addNode(lops, lop))
+ for (Lop in : lop.getInputs())
+ addToLopList(lops, in);
+ }
+
+}
diff --git a/src/main/java/org/apache/sysds/lops/UnaryCP.java
b/src/main/java/org/apache/sysds/lops/UnaryCP.java
index 4b95e8c1b0..7dd6a30e58 100644
--- a/src/main/java/org/apache/sysds/lops/UnaryCP.java
+++ b/src/main/java/org/apache/sysds/lops/UnaryCP.java
@@ -66,7 +66,7 @@ public class UnaryCP extends Lop {
return "Operation: " + getInstructions("", "");
}
- private String getOpCode() {
+ public String getOpCode() {
return operation.toString();
}
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index eef3085917..656a0262d6 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -34,7 +34,6 @@ import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ExecType;
-import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
@@ -43,24 +42,19 @@ import org.apache.sysds.lops.CSVReBlock;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.CoVariance;
-import org.apache.sysds.lops.DataGen;
-import org.apache.sysds.lops.FunctionCallCP;
import org.apache.sysds.lops.GroupedAggregate;
import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
-import org.apache.sysds.lops.MMCJ;
-import org.apache.sysds.lops.MMRJ;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.lops.MMZip;
-import org.apache.sysds.lops.MapMult;
import org.apache.sysds.lops.MapMultChain;
+import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.ParameterizedBuiltin;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.ReBlock;
import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.UnaryCP;
-import org.apache.sysds.parser.DMLProgram;
/**
* A interface for the linearization algorithms that order the DAG nodes into
a sequence of instructions to execute.
@@ -187,15 +181,17 @@ public interface ILinearize {
private static List<Lop> doMaxParallelizeSort(List<Lop> v)
{
List<Lop> final_v = null;
- if (v.stream().anyMatch(ILinearize::isSparkTriggeringOp)) {
+ // Fallback to default depth-first if all operators are CP
+ if (v.stream().anyMatch(ILinearize::isDistributedOp)) {
// Step 1: Collect the Spark roots and #Spark
instructions in each subDAG
Map<Long, Integer> sparkOpCount = new HashMap<>();
- List<Lop> roots =
v.stream().filter(ILinearize::isRoot).collect(Collectors.toList());
+ List<Lop> roots =
v.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
List<Lop> sparkRoots = new ArrayList<>();
- roots.forEach(r -> collectSparkRoots(r, sparkOpCount,
sparkRoots));
+ roots.forEach(r ->
OperatorOrderingUtils.collectSparkRoots(r, sparkOpCount, sparkRoots));
- // Step 2: Depth-first linearization. Place the CP OPs
first to increase broadcast potentials.
+ // Step 2: Depth-first linearization of Spark roots.
// Maintain the default order (by ID) to trigger
independent Spark jobs first
+ // This allows parallel execution of the jobs in the
cluster
ArrayList<Lop> operatorList = new ArrayList<>();
sparkRoots.forEach(r -> depthFirst(r, operatorList,
sparkOpCount, false));
@@ -204,81 +200,12 @@ public interface ILinearize {
roots.forEach(r -> depthFirst(r, operatorList,
sparkOpCount, false));
roots.forEach(Lop::resetVisitStatus);
- // Step 4: Add Chkpoint lops after the expensive Spark
operators, which
- // are shared among multiple Spark jobs. Only consider
operators with
- // Spark consumers for now.
- Map<Long, Integer> operatorJobCount = new HashMap<>();
- markPersistableSparkOps(sparkRoots, operatorJobCount);
- final_v = addChkpointLop(operatorList,
operatorJobCount);
- // TODO: A rewrite pass to remove less effective
chkpoints
+ final_v = operatorList;
}
else
- // Fall back to depth if none of the operators returns
results back to local
final_v = depthFirst(v);
- // Step 4: Add Prefetch and Broadcast lops if necessary
- List<Lop> v_pf = ConfigurationManager.isPrefetchEnabled() ?
addPrefetchLop(final_v) : final_v;
- List<Lop> v_bc = ConfigurationManager.isBroadcastEnabled() ?
addBroadcastLop(v_pf) : v_pf;
-
- return v_bc;
- }
-
- private static boolean isRoot(Lop lop) {
- if (lop.getOutputs().isEmpty())
- return true;
- if (lop instanceof FunctionCallCP &&
- ((FunctionCallCP)
lop).getFnamespace().equalsIgnoreCase(DMLProgram.INTERNAL_NAMESPACE)) {
- return true;
- }
- return false;
- }
-
- // Gather the Spark operators which return intermediates to local
(actions/single_block)
- // In addition count the number of Spark OPs underneath every Operator
- private static int collectSparkRoots(Lop root, Map<Long, Integer>
sparkOpCount, List<Lop> sparkRoots) {
- if (sparkOpCount.containsKey(root.getID())) //visited before
- return sparkOpCount.get(root.getID());
-
- // Aggregate #Spark operators in the child DAGs
- int total = 0;
- for (Lop input : root.getInputs())
- total += collectSparkRoots(input, sparkOpCount,
sparkRoots);
-
- // Check if this node is Spark
- total = root.isExecSpark() ? total + 1 : total;
- sparkOpCount.put(root.getID(), total);
-
- // Triggering point: Spark action/operator with all CP consumers
- if (isSparkTriggeringOp(root)) {
- sparkRoots.add(root);
- root.setAsynchronous(true); //candidate for async.
execution
- }
-
- return total;
- }
-
- // Count the number of jobs a Spark operator is part of
- private static void markPersistableSparkOps(List<Lop> sparkRoots,
Map<Long, Integer> operatorJobCount) {
- for (Lop root : sparkRoots) {
- collectPersistableSparkOps(root, operatorJobCount);
- root.resetVisitStatus();
- }
- }
-
- private static void collectPersistableSparkOps(Lop root, Map<Long,
Integer> operatorJobCount) {
- if (root.isVisited())
- return;
-
- for (Lop input : root.getInputs())
- if (root.getBroadcastInput() != input)
- collectPersistableSparkOps(input,
operatorJobCount);
-
- // Increment the job counter if this node benefits from
persisting
- // and reachable from multiple job roots
- if (isPersistableSparkOp(root))
- operatorJobCount.merge(root.getID(), 1, Integer::sum);
-
- root.setVisited();
+ return final_v;
}
// Place the operators in a depth-first manner, but order
@@ -306,104 +233,11 @@ public interface ILinearize {
root.setVisited();
}
- private static boolean isSparkTriggeringOp(Lop lop) {
- return lop.isExecSpark() && (lop.getAggType() ==
SparkAggType.SINGLE_BLOCK
- || lop.getDataType() == DataType.SCALAR || lop
instanceof MapMultChain
- || lop instanceof PickByCount || lop instanceof MMZip
|| lop instanceof CentralMoment
- || lop instanceof CoVariance || lop instanceof MMTSJ ||
lop.isAllOutputsCP())
- || isCollectForBroadcast(lop);
- }
-
- private static boolean isCollectForBroadcast(Lop lop) {
- boolean isSparkOp = lop.isExecSpark();
- boolean isBc = lop.getOutputs().stream()
- .allMatch(out -> (out.getBroadcastInput() == lop));
- //TODO: Handle Lops with mixed Spark (broadcast) CP consumers
- return isSparkOp && isBc && (lop.getDataType() ==
DataType.MATRIX);
- }
-
- // Dictionary of Spark operators which are expensive enough to be
- // benefited from persisting if shared among jobs.
- private static boolean isPersistableSparkOp(Lop lop) {
- return lop.isExecSpark() && (lop instanceof MapMult
- || lop instanceof MMCJ || lop instanceof MMRJ
- || lop instanceof MMZip);
- }
-
- private static List<Lop> addChkpointLop(List<Lop> nodes, Map<Long,
Integer> operatorJobCount) {
- List<Lop> nodesWithChkpt = new ArrayList<>();
-
- for (Lop l : nodes) {
- nodesWithChkpt.add(l);
- if(operatorJobCount.containsKey(l.getID()) &&
operatorJobCount.get(l.getID()) > 1) {
- //This operation is expensive and shared
between Spark jobs
- List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
- //Construct a chkpoint lop that takes this
Spark node as a input
- Lop chkpoint = new Checkpoint(l,
l.getDataType(), l.getValueType(),
-
Checkpoint.getDefaultStorageLevelString(), false);
- for (Lop out : oldOuts) {
- //Rewire l -> out to l -> chkpoint ->
out
- chkpoint.addOutput(out);
- out.replaceInput(l, chkpoint);
- l.removeOutput(out);
- }
- //Place it immediately after the Spark lop in
the node list
- nodesWithChkpt.add(chkpoint);
- }
- }
- return nodesWithChkpt;
- }
-
- private static List<Lop> addPrefetchLop(List<Lop> nodes) {
- List<Lop> nodesWithPrefetch = new ArrayList<>();
-
- //Find the Spark nodes with all CP outputs
- for (Lop l : nodes) {
- nodesWithPrefetch.add(l);
- if (isPrefetchNeeded(l)) {
- List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
- //Construct a Prefetch lop that takes this
Spark node as a input
- UnaryCP prefetch = new UnaryCP(l,
OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
- prefetch.setAsynchronous(true);
- //Reset asynchronous flag for the input if
already set (e.g. mapmm -> prefetch)
- l.setAsynchronous(false);
- for (Lop outCP : oldOuts) {
- //Rewire l -> outCP to l -> Prefetch ->
outCP
- prefetch.addOutput(outCP);
- outCP.replaceInput(l, prefetch);
- l.removeOutput(outCP);
- //FIXME: Rewire _inputParams when
needed (e.g. GroupedAggregate)
- }
- //Place it immediately after the Spark lop in
the node list
- nodesWithPrefetch.add(prefetch);
- }
- }
- return nodesWithPrefetch;
- }
-
- private static List<Lop> addBroadcastLop(List<Lop> nodes) {
- List<Lop> nodesWithBroadcast = new ArrayList<>();
-
- for (Lop l : nodes) {
- nodesWithBroadcast.add(l);
- if (isBroadcastNeeded(l)) {
- List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
- //Construct a Broadcast lop that takes this
Spark node as an input
- UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST,
l.getDataType(), l.getValueType(), ExecType.CP);
- bc.setAsynchronous(true);
- //FIXME: Wire Broadcast only with the necessary
outputs
- for (Lop outCP : oldOuts) {
- //Rewire l -> outCP to l -> Broadcast
-> outCP
- bc.addOutput(outCP);
- outCP.replaceInput(l, bc);
- l.removeOutput(outCP);
- //FIXME: Rewire _inputParams when
needed (e.g. GroupedAggregate)
- }
- //Place it immediately after the Spark lop in
the node list
- nodesWithBroadcast.add(bc);
- }
- }
- return nodesWithBroadcast;
+ private static boolean isDistributedOp(Lop lop) {
+ return lop.isExecSpark()
+ || (lop instanceof UnaryCP
+ && (((UnaryCP)
lop).getOpCode().equalsIgnoreCase("prefetch")
+ || ((UnaryCP)
lop).getOpCode().equalsIgnoreCase("broadcast")));
}
@SuppressWarnings("unused")
@@ -432,43 +266,6 @@ public interface ILinearize {
return nodesWithCheckpoint;
}
- private static boolean isPrefetchNeeded(Lop lop) {
- // Run Prefetch for a Spark instruction if the instruction is a
Transformation
- // and the output is consumed by only CP instructions.
- boolean transformOP = lop.getExecType() == ExecType.SPARK &&
lop.getAggType() != SparkAggType.SINGLE_BLOCK
- // Always Action operations
- && !(lop.getDataType() == DataType.SCALAR)
- && !(lop instanceof MapMultChain) && !(lop
instanceof PickByCount)
- && !(lop instanceof MMZip) && !(lop instanceof
CentralMoment)
- && !(lop instanceof CoVariance)
- // Not qualified for prefetching
- && !(lop instanceof Checkpoint) && !(lop
instanceof ReBlock)
- && !(lop instanceof CSVReBlock) && !(lop
instanceof DataGen)
- // Cannot filter Transformation cases from
Actions (FIXME)
- && !(lop instanceof MMTSJ) && !(lop instanceof
UAggOuterChain)
- && !(lop instanceof ParameterizedBuiltin) &&
!(lop instanceof SpoofFused);
-
- //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
- boolean hasParameterizedOut = lop.getOutputs().stream()
- .anyMatch(out -> ((out instanceof
ParameterizedBuiltin)
- || (out instanceof GroupedAggregate)
- || (out instanceof GroupedAggregateM)));
- //TODO: support non-matrix outputs
- return transformOP && !hasParameterizedOut
- && (lop.isAllOutputsCP() ||
isCollectForBroadcast(lop))
- && lop.getDataType() == DataType.MATRIX;
- }
-
- private static boolean isBroadcastNeeded(Lop lop) {
- // Asynchronously broadcast a matrix if that is produced by a
CP instruction,
- // and at least one Spark parent needs to broadcast this
intermediate (eg. mapmm)
- boolean isBc = lop.getOutputs().stream()
- .anyMatch(out -> (out.getBroadcastInput() ==
lop));
- //TODO: Early broadcast objects that are bigger than a single
block
- //return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
- return isBc && lop.getDataType() == DataType.MATRIX;
- }
-
private static boolean isCheckpointNeeded(Lop lop) {
// Place checkpoint_e just before a Spark action (FIXME)
boolean actionOP = lop.getExecType() == ExecType.SPARK
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriteRule.java
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriteRule.java
new file mode 100644
index 0000000000..5af5d65244
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriteRule.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.List;
+
+public abstract class LopRewriteRule
+{
+ public abstract List<StatementBlock>
rewriteLOPinStatementBlock(StatementBlock sb);
+ public abstract List<StatementBlock>
rewriteLOPinStatementBlocks(List<StatementBlock> sb);
+}
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
new file mode 100644
index 0000000000..4567cf1c4e
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -0,0 +1,134 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class LopRewriter
+{
+ private ArrayList<LopRewriteRule> _lopSBRuleSet = null;
+
+ public LopRewriter() {
+ _lopSBRuleSet = new ArrayList<>();
+ // Add rewrite rules (single and multi-statement block)
+ _lopSBRuleSet.add(new RewriteAddPrefetchLop());
+ _lopSBRuleSet.add(new RewriteAddBroadcastLop());
+ _lopSBRuleSet.add(new RewriteAddChkpointLop());
+ // TODO: A rewrite pass to remove less effective chkpoints
+ // Last rewrite to reset Lop IDs in a depth-first manner
+ _lopSBRuleSet.add(new RewriteFixIDs());
+ }
+
+ public void rewriteProgramLopDAGs(DMLProgram dmlp) {
+ for (String namespaceKey : dmlp.getNamespaces().keySet())
+ // for each namespace, handle function statement blocks
+ for (String fname :
dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
+ FunctionStatementBlock fsblock =
dmlp.getFunctionStatementBlock(namespaceKey,fname);
+ rewriteLopDAGsFunction(fsblock);
+ }
+
+ if (!_lopSBRuleSet.isEmpty()) {
+ ArrayList<StatementBlock> sbs =
rRewriteLops(dmlp.getStatementBlocks());
+ dmlp.setStatementBlocks(sbs);
+ }
+ }
+
+ public void rewriteLopDAGsFunction(FunctionStatementBlock fsb) {
+ if( !_lopSBRuleSet.isEmpty() )
+ rRewriteLop(fsb);
+ }
+
+ public ArrayList<Lop> rewriteLopDAG(ArrayList<Lop> lops) {
+ StatementBlock sb = new StatementBlock();
+ sb.setLops(lops);
+ return rRewriteLop(sb).get(0).getLops();
+ }
+
+ public ArrayList<StatementBlock> rRewriteLops(ArrayList<StatementBlock>
sbs) {
+ // Apply rewrite rules to the lops of the list of statement
blocks
+ List<StatementBlock> tmp = sbs;
+ for(LopRewriteRule r : _lopSBRuleSet)
+ tmp = r.rewriteLOPinStatementBlocks(tmp);
+
+ // Recursively rewrite lops in statement blocks
+ List<StatementBlock> tmp2 = new ArrayList<>();
+ for( StatementBlock sb : tmp )
+ tmp2.addAll(rRewriteLop(sb));
+
+ // Prepare output list
+ sbs.clear();
+ sbs.addAll(tmp2);
+ return sbs;
+ }
+
+ public ArrayList<StatementBlock> rRewriteLop(StatementBlock sb) {
+ ArrayList<StatementBlock> ret = new ArrayList<>();
+ ret.add(sb);
+
+ // Recursive invocation
+ if (sb instanceof FunctionStatementBlock) {
+ FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
+ FunctionStatement fstmt =
(FunctionStatement)fsb.getStatement(0);
+ fstmt.setBody(rRewriteLops(fstmt.getBody()));
+ }
+ else if (sb instanceof WhileStatementBlock) {
+ WhileStatementBlock wsb = (WhileStatementBlock) sb;
+ WhileStatement wstmt =
(WhileStatement)wsb.getStatement(0);
+ wstmt.setBody(rRewriteLops(wstmt.getBody()));
+ }
+ else if (sb instanceof IfStatementBlock) {
+ IfStatementBlock isb = (IfStatementBlock) sb;
+ IfStatement istmt = (IfStatement)isb.getStatement(0);
+ istmt.setIfBody(rRewriteLops(istmt.getIfBody()));
+ istmt.setElseBody(rRewriteLops(istmt.getElseBody()));
+ }
+ else if (sb instanceof ForStatementBlock) { //incl parfor
+ //TODO: parfor statement blocks
+ ForStatementBlock fsb = (ForStatementBlock) sb;
+ ForStatement fstmt = (ForStatement)fsb.getStatement(0);
+ fstmt.setBody(rRewriteLops(fstmt.getBody()));
+ }
+
+ // Apply rewrite rules to individual statement blocks
+ for(LopRewriteRule r : _lopSBRuleSet) {
+ ArrayList<StatementBlock> tmp = new ArrayList<>();
+ for( StatementBlock sbc : ret )
+ tmp.addAll( r.rewriteLOPinStatementBlock(sbc) );
+
+ // Take over set of rewritten sbs
+ ret.clear();
+ ret.addAll(tmp);
+ }
+
+ return ret;
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java
new file mode 100644
index 0000000000..da22c51186
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.lops.UnaryCP;
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class RewriteAddBroadcastLop extends LopRewriteRule
+{
+ @Override
+ public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock
sb)
+ {
+ if (!ConfigurationManager.isBroadcastEnabled())
+ return List.of(sb);
+
+ ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
+ if (lops == null)
+ return List.of(sb);
+
+ ArrayList<Lop> nodesWithBroadcast = new ArrayList<>();
+ for (Lop l : lops) {
+ nodesWithBroadcast.add(l);
+ if (isBroadcastNeeded(l)) {
+ List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
+ // Construct a Broadcast lop that takes this
Spark node as an input
+ UnaryCP bc = new UnaryCP(l,
Types.OpOp1.BROADCAST, l.getDataType(), l.getValueType(), Types.ExecType.CP);
+ bc.setAsynchronous(true);
+ //FIXME: Wire Broadcast only with the necessary
outputs
+ for (Lop outCP : oldOuts) {
+ // Rewire l -> outCP to l -> Broadcast
-> outCP
+ bc.addOutput(outCP);
+ outCP.replaceInput(l, bc);
+ l.removeOutput(outCP);
+ //FIXME: Rewire _inputParams when
needed (e.g. GroupedAggregate)
+ }
+ //Place it immediately after the Spark lop in
the node list
+ nodesWithBroadcast.add(bc);
+ }
+ }
+ // New node is added inplace in the Lop DAG
+ return Arrays.asList(sb);
+ }
+
+ @Override
+ public List<StatementBlock>
rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+ return sbs;
+ }
+
+ private static boolean isBroadcastNeeded(Lop lop) {
+ // Asynchronously broadcast a matrix if that is produced by a
CP instruction,
+ // and at least one Spark parent needs to broadcast this
intermediate (eg. mapmm)
+ boolean isBc = lop.getOutputs().stream()
+ .anyMatch(out -> (out.getBroadcastInput() == lop));
+ //TODO: Early broadcast objects that are bigger than a single
block
+ boolean isCP = lop.getExecType() == Types.ExecType.CP;
+ return isCP && isBc && lop.getDataType() ==
Types.DataType.MATRIX;
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
new file mode 100644
index 0000000000..d4f976a795
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.lops.Checkpoint;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+public class RewriteAddChkpointLop extends LopRewriteRule
+{
+ @Override
+ public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock
sb)
+ {
+ if (!ConfigurationManager.isCheckpointEnabled())
+ return List.of(sb);
+
+ ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
+ if (lops == null)
+ return List.of(sb);
+
+ // Collect the Spark roots and #Spark instructions in each
subDAG
+ List<Lop> sparkRoots = new ArrayList<>();
+ Map<Long, Integer> sparkOpCount = new HashMap<>();
+ List<Lop> roots =
lops.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
+ roots.forEach(r -> OperatorOrderingUtils.collectSparkRoots(r,
sparkOpCount, sparkRoots));
+ if (sparkRoots.isEmpty())
+ return List.of(sb);
+
+ // Add Chkpoint lops after the expensive Spark operators, which
are
+ // shared among multiple Spark jobs. Only consider operators
with
+ // Spark consumers for now.
+ Map<Long, Integer> operatorJobCount = new HashMap<>();
+ markPersistableSparkOps(sparkRoots, operatorJobCount);
+ // TODO: A rewrite pass to remove less effective chkpoints
+ List<Lop> nodesWithChkpt = addChkpointLop(lops,
operatorJobCount);
+ //New node is added inplace in the Lop DAG
+ return List.of(sb);
+ }
+
+ @Override
+ public List<StatementBlock>
rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+ return sbs;
+ }
+
+ private static List<Lop> addChkpointLop(List<Lop> nodes, Map<Long,
Integer> operatorJobCount) {
+ List<Lop> nodesWithChkpt = new ArrayList<>();
+
+ for (Lop l : nodes) {
+ nodesWithChkpt.add(l);
+ if(operatorJobCount.containsKey(l.getID()) &&
operatorJobCount.get(l.getID()) > 1) {
+ // This operation is expensive and shared
between Spark jobs
+ List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
+ // Construct a chkpoint lop that takes this
Spark node as a input
+ Lop chkpoint = new Checkpoint(l,
l.getDataType(), l.getValueType(),
+
Checkpoint.getDefaultStorageLevelString(), false);
+ for (Lop out : oldOuts) {
+ //Rewire l -> out to l -> chkpoint ->
out
+ chkpoint.addOutput(out);
+ out.replaceInput(l, chkpoint);
+ l.removeOutput(out);
+ }
+ // Place it immediately after the Spark lop in
the node list
+ nodesWithChkpt.add(chkpoint);
+ }
+ }
+ return nodesWithChkpt;
+ }
+
+ // Count the number of jobs a Spark operator is part of
+ private static void markPersistableSparkOps(List<Lop> sparkRoots,
Map<Long, Integer> operatorJobCount) {
+ for (Lop root : sparkRoots) {
+ collectPersistableSparkOps(root, operatorJobCount);
+ root.resetVisitStatus();
+ }
+ }
+
+ private static void collectPersistableSparkOps(Lop root, Map<Long,
Integer> operatorJobCount) {
+ if (root.isVisited())
+ return;
+
+ for (Lop input : root.getInputs())
+ if (root.getBroadcastInput() != input)
+ collectPersistableSparkOps(input,
operatorJobCount);
+
+ // Increment the job counter if this node benefits from
persisting
+ // and reachable from multiple job roots
+ if (OperatorOrderingUtils.isPersistableSparkOp(root))
+ operatorJobCount.merge(root.getID(), 1, Integer::sum);
+
+ root.setVisited();
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
new file mode 100644
index 0000000000..6eb52e0d9f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.lops.CSVReBlock;
+import org.apache.sysds.lops.CentralMoment;
+import org.apache.sysds.lops.Checkpoint;
+import org.apache.sysds.lops.CoVariance;
+import org.apache.sysds.lops.DataGen;
+import org.apache.sysds.lops.GroupedAggregate;
+import org.apache.sysds.lops.GroupedAggregateM;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.MMTSJ;
+import org.apache.sysds.lops.MMZip;
+import org.apache.sysds.lops.MapMultChain;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.lops.ParameterizedBuiltin;
+import org.apache.sysds.lops.PickByCount;
+import org.apache.sysds.lops.ReBlock;
+import org.apache.sysds.lops.SpoofFused;
+import org.apache.sysds.lops.UAggOuterChain;
+import org.apache.sysds.lops.UnaryCP;
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class RewriteAddPrefetchLop extends LopRewriteRule
+{
+ @Override
+ public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock
sb)
+ {
+ if (!ConfigurationManager.isPrefetchEnabled())
+ return List.of(sb);
+
+ ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
+ if (lops == null)
+ return List.of(sb);
+
+ ArrayList<Lop> nodesWithPrefetch = new ArrayList<>();
+ //Find the Spark nodes with all CP outputs
+ for (Lop l : lops) {
+ nodesWithPrefetch.add(l);
+ if (isPrefetchNeeded(l)) {
+ List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
+ //Construct a Prefetch lop that takes this
Spark node as a input
+ UnaryCP prefetch = new UnaryCP(l,
Types.OpOp1.PREFETCH, l.getDataType(), l.getValueType(), Types.ExecType.CP);
+ prefetch.setAsynchronous(true);
+ //Reset asynchronous flag for the input if
already set (e.g. mapmm -> prefetch)
+ l.setAsynchronous(false);
+ for (Lop outCP : oldOuts) {
+ //Rewire l -> outCP to l -> Prefetch ->
outCP
+ prefetch.addOutput(outCP);
+ outCP.replaceInput(l, prefetch);
+ l.removeOutput(outCP);
+ //FIXME: Rewire _inputParams when
needed (e.g. GroupedAggregate)
+ }
+ //Place it immediately after the Spark lop in
the node list
+ nodesWithPrefetch.add(prefetch);
+ }
+ }
+ //New node is added inplace in the Lop DAG
+ return Arrays.asList(sb);
+ }
+
+ @Override
+ public List<StatementBlock>
rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+ return sbs;
+ }
+
+ private boolean isPrefetchNeeded(Lop lop) {
+ // Run Prefetch for a Spark instruction if the instruction is a
Transformation
+ // and the output is consumed by only CP instructions.
+ boolean transformOP = lop.getExecType() == Types.ExecType.SPARK
&& lop.getAggType() != AggBinaryOp.SparkAggType.SINGLE_BLOCK
+ // Always Action operations
+ && !(lop.getDataType() == Types.DataType.SCALAR)
+ && !(lop instanceof MapMultChain) && !(lop instanceof
PickByCount)
+ && !(lop instanceof MMZip) && !(lop instanceof
CentralMoment)
+ && !(lop instanceof CoVariance)
+ // Not qualified for prefetching
+ && !(lop instanceof Checkpoint) && !(lop instanceof
ReBlock)
+ && !(lop instanceof CSVReBlock) && !(lop instanceof
DataGen)
+ // Cannot filter Transformation cases from Actions
(FIXME)
+ && !(lop instanceof MMTSJ) && !(lop instanceof
UAggOuterChain)
+ && !(lop instanceof ParameterizedBuiltin) && !(lop
instanceof SpoofFused);
+
+ //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
+ boolean hasParameterizedOut = lop.getOutputs().stream()
+ .anyMatch(out -> ((out instanceof ParameterizedBuiltin)
+ || (out instanceof GroupedAggregate)
+ || (out instanceof GroupedAggregateM)));
+ //TODO: support non-matrix outputs
+ return transformOP && !hasParameterizedOut
+ && (lop.isAllOutputsCP() ||
OperatorOrderingUtils.isCollectForBroadcast(lop))
+ && lop.getDataType() == Types.DataType.MATRIX;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
new file mode 100644
index 0000000000..00d205b553
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.List;
+
+public class RewriteFixIDs extends LopRewriteRule
+{
+ @Override
+ public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock
sb)
+ {
+ // Skip if no new Lop nodes are added
+ if (!ConfigurationManager.isPrefetchEnabled() &&
!ConfigurationManager.isBroadcastEnabled()
+ && !ConfigurationManager.isCheckpointEnabled())
+ return List.of(sb);
+
+ // Reset the IDs in a depth-first manner
+ if (sb.getLops() != null && !sb.getLops().isEmpty()) {
+ for (Lop root : sb.getLops())
+ assignNewID(root);
+ sb.getLops().forEach(Lop::resetVisitStatus);
+ }
+ return List.of(sb);
+ }
+
+ @Override
+ public List<StatementBlock>
rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+ return sbs;
+ }
+
+ private void assignNewID(Lop lop) {
+ if (lop.isVisited())
+ return;
+
+ if (lop.getInputs().isEmpty()) { //leaf node
+ lop.setNewID();
+ lop.setVisited();
+ return;
+ }
+ for (Lop input : lop.getInputs())
+ assignNewID(input);
+
+ lop.setNewID();
+ lop.setVisited();
+ }
+}
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 03e2856e6c..c6ed2c5b84 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -78,6 +78,7 @@ import org.apache.sysds.hops.rewrite.ProgramRewriter;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopsException;
import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.lops.rewrite.LopRewriter;
import org.apache.sysds.parser.PrintStatement.PRINTTYPE;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
@@ -312,6 +313,11 @@ public class DMLTranslator
codgenHopsDAG(dmlp);
}
}
+
+ public void rewriteLopDAG(DMLProgram dmlp) {
+ LopRewriter rewriter = new LopRewriter();
+ rewriter.rewriteProgramLopDAGs(dmlp);
+ }
public void codgenHopsDAG(DMLProgram dmlp) {
SpoofCompiler.generateCode(dmlp);
@@ -482,7 +488,7 @@ public class DMLTranslator
}
public ProgramBlock createRuntimeProgramBlock(Program prog,
StatementBlock sb, DMLConfig config) {
- Dag<Lop> dag = null;
+ Dag<Lop> dag = null;
Dag<Lop> pred_dag = null;
ArrayList<Instruction> instruct;
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
index c8b7fdd94f..800655a8fc 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
@@ -71,8 +71,6 @@ public class AsyncBroadcastTest extends AutomatedTestBase {
InfrastructureAnalyzer.setLocalMaxMemory(mem);
try {
- //OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
- //OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false;
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
getAndLoadTestConfiguration(testname);
fullDMLScriptName = getScript();
@@ -88,11 +86,9 @@ public class AsyncBroadcastTest extends AutomatedTestBase {
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> R =
readDMLScalarFromOutputDir("R");
- OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
OptimizerUtils.ASYNC_BROADCAST_SPARK = true;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
OptimizerUtils.ASYNC_BROADCAST_SPARK = false;
- OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
HashMap<MatrixValue.CellIndex, Double> R_bc =
readDMLScalarFromOutputDir("R");
//compare matrices
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
index 1a899d3d66..eda92023b9 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
@@ -79,11 +79,11 @@ public class CheckpointSharedOpsTest extends
AutomatedTestBase {
HashMap<MatrixValue.CellIndex, Double> R =
readDMLScalarFromOutputDir("R");
long numCP =
Statistics.getCPHeavyHitterCount("sp_chkpoint");
- OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
+ OptimizerUtils.ASYNC_CHECKPOINT_SPARK = true;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> R_mp =
readDMLScalarFromOutputDir("R");
long numCP_maxp =
Statistics.getCPHeavyHitterCount("sp_chkpoint");
- OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
+ OptimizerUtils.ASYNC_CHECKPOINT_SPARK = false;
//compare matrices
boolean matchVal = TestUtils.compareMatrices(R, R_mp,
1e-6, "Origin", "withPrefetch");
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index b863c81a29..f821af5eb0 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -97,7 +97,6 @@ public class PrefetchRDDTest extends AutomatedTestBase {
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> R =
readDMLScalarFromOutputDir("R");
- OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
OptimizerUtils.ASYNC_PREFETCH_SPARK = false;