This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new d3480c1ee0 [SYSTEMDS-3562] Save multi-statementblock checkpoints for
recompiler
d3480c1ee0 is described below
commit d3480c1ee032694aaf86839efb5ad656c041f16f
Author: Arnab Phani <[email protected]>
AuthorDate: Wed Jun 14 10:53:34 2023 +0200
[SYSTEMDS-3562] Save multi-statementblock checkpoints for recompiler
This patch adds a temporary fix that saves the positions of the checkpoint
instructions placed in a loop body during compilation and places them
there again during recompilation. A better fix would be to enable
recompilation for that loop or function.
Closes #1844
---
.../apache/sysds/hops/recompile/Recompiler.java | 4 +-
.../org/apache/sysds/lops/rewrite/LopRewriter.java | 4 +-
.../lops/rewrite/RewriteAddChkpointInLoop.java | 6 ++-
.../sysds/lops/rewrite/RewriteAddChkpointLop.java | 47 +++++++++++++++++++++-
.../org/apache/sysds/parser/StatementBlock.java | 18 +++++++++
.../functions/async/CheckpointSharedOpsTest.java | 3 +-
6 files changed, 73 insertions(+), 9 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index d5c8169d68..01945f90d9 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -404,8 +404,8 @@ public class Recompiler {
}
// dynamic lop rewrites for the updated hop DAGs
- if (rewrittenHops)
- _lopRewriter.get().rewriteLopDAG(lops);
+ if (rewrittenHops && sb != null)
+ _lopRewriter.get().rewriteLopDAG(sb, lops);
Dag<Lop> dag = new Dag<>();
for (Lop l : lops)
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
index 55b590543a..2b054d9b2b 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -69,8 +69,8 @@ public class LopRewriter
rRewriteLop(fsb);
}
- public ArrayList<Lop> rewriteLopDAG(ArrayList<Lop> lops) {
- StatementBlock sb = new StatementBlock();
+ public ArrayList<Lop> rewriteLopDAG(StatementBlock sb, ArrayList<Lop>
lops) {
+ //StatementBlock sb = new StatementBlock();
sb.setLops(lops);
return rRewriteLop(sb).get(0).getLops();
}
diff --git
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
index 5e445a6bce..27a1a552ab 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
@@ -81,7 +81,7 @@ public class RewriteAddChkpointInLoop extends LopRewriteRule
return List.of(sb);
// Add checkpoint Lops after the shared operators
- addChkpointLop(lops, operatorJobCount);
+ addChkpointLop(lops, operatorJobCount, csb);
// TODO: A rewrite pass to remove less effective checkpoints
return List.of(sb);
}
@@ -91,7 +91,7 @@ public class RewriteAddChkpointInLoop extends LopRewriteRule
return sbs;
}
- private void addChkpointLop(List<Lop> nodes, Map<Long, Integer>
operatorJobCount) {
+ private void addChkpointLop(List<Lop> nodes, Map<Long, Integer>
operatorJobCount, StatementBlock sb) {
for (Lop l : nodes) {
if(operatorJobCount.containsKey(l.getID()) &&
operatorJobCount.get(l.getID()) > 1) {
// TODO: Check if this lop leads to one of
those variables
@@ -106,6 +106,8 @@ public class RewriteAddChkpointInLoop extends LopRewriteRule
out.replaceInput(l, checkpoint);
l.removeOutput(out);
}
+ // Save the checkpoint position for the
recompiler
+ sb.setCheckpointPosition(l, oldOuts);
}
}
}
diff --git
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
index 6a1a1192ea..701c604bc8 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
@@ -60,6 +60,7 @@ public class RewriteAddChkpointLop extends LopRewriteRule
OperatorOrderingUtils.markSharedSparkOps(sparkRoots,
operatorJobCount);
// TODO: A rewrite pass to remove less effective checkpoints
addChkpointLop(lops, operatorJobCount);
+ placeCompiledCheckpoints(lops, sb);
//New node is added inplace in the Lop DAG
return List.of(sb);
}
@@ -78,7 +79,7 @@ public class RewriteAddChkpointLop extends LopRewriteRule
&&
OperatorOrderingUtils.isPersistableSparkOp(l)) {
// This operation is expensive and shared
between Spark jobs
List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
- // Construct a chkpoint lop that takes this
Spark node as a input
+ // Construct a chkpoint lop that takes this
Spark node as an input
Lop chkpoint = new Checkpoint(l,
l.getDataType(), l.getValueType(),
Checkpoint.getDefaultStorageLevelString(), false);
for (Lop out : oldOuts) {
@@ -90,4 +91,48 @@ public class RewriteAddChkpointLop extends LopRewriteRule
}
}
}
+
+ private void placeCompiledCheckpoints(List<Lop> nodes, StatementBlock
sb) {
+ if (sb.getCheckpointPositions() == null)
+ return;
+
+ for (Lop l : nodes) {
+ // Check if the compiler placed and saved a checkpoint
+ // TODO: Call recompiler on the loops
+ if (isCheckpointed(l, sb)) {
+ List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
+ // Construct a chkpoint lop that takes this
Spark node as an input
+ Lop chkpoint = new Checkpoint(l,
l.getDataType(), l.getValueType(),
+
Checkpoint.getDefaultStorageLevelString(), false);
+ for (Lop out : oldOuts) {
+ //Rewire l -> out to l -> chkpoint ->
out
+ chkpoint.addOutput(out);
+ out.replaceInput(l, chkpoint);
+ l.removeOutput(out);
+ }
+ }
+ }
+ }
+
+ private boolean isCheckpointed(Lop lop, StatementBlock sb) {
+ var cpPositions = sb.getCheckpointPositions();
+ if (cpPositions == null)
+ return false;
+
+ if (cpPositions.containsKey(lop.getType())) {
+ List<Lop.Type> outputsT =
cpPositions.get(lop.getType());
+ List<Lop> outputs = new ArrayList<>(lop.getOutputs());
+ if (outputs.size() != outputsT.size())
+ return false;
+ for (int i=0; i< outputs.size(); i++) {
+ if (outputs.get(i).getType() != outputsT.get(i)
+ || !outputs.get(i).isExecSpark())
+ return false;
+ }
+ }
+ else
+ return false;
+
+ return true;
+ }
}
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index b4ee82405b..3deb6a8001 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
+import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -64,6 +65,7 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
private boolean _requiresRecompile = false;
private boolean _splitDag = false;
private boolean _nondeterministic = false;
+ private HashMap<Lop.Type, List<Lop.Type>> _checkpointPositions = null;
protected double repetitions = 1;
public final static double DEFAULT_LOOP_REPETITIONS = 10;
@@ -1393,4 +1395,20 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
public boolean isNondeterministic() {
return _nondeterministic;
}
+
+ public void setCheckpointPosition(Lop input, List<Lop> outputs) {
+ // FIXME: Type is not the best key as many Lops may have the
same types
+ Lop.Type inputT = input.getType();
+ List<Lop.Type> outputsT =
outputs.stream().map(Lop::getType).collect(Collectors.toList());
+
+ if (_checkpointPositions == null)
+ _checkpointPositions = new HashMap<>();
+ if (!_checkpointPositions.containsKey(inputT)) {
+ _checkpointPositions.put(inputT, outputsT);
+ }
+ }
+
+ public HashMap<Lop.Type, List<Lop.Type>> getCheckpointPositions() {
+ return _checkpointPositions;
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
index bceb4e2090..6898b9ba88 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
@@ -96,8 +96,7 @@ public class CheckpointSharedOpsTest extends
AutomatedTestBase {
if (!matchVal)
System.out.println("Value w/o Checkpoint "+R+"
w/ Checkpoint "+R_mp);
//compare checkpoint instruction count
- if (!testname.equalsIgnoreCase(TEST_NAME+"2"))
- Assert.assertTrue("Violated checkpoint count: "
+ numCP + " < " + numCP_maxp, numCP < numCP_maxp);
+ Assert.assertTrue("Violated checkpoint count: " + numCP
+ " < " + numCP_maxp, numCP < numCP_maxp);
} finally {
resetExecMode(oldPlatform);
InfrastructureAnalyzer.setLocalMaxMemory(oldmem);