This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new afa315d831 [SYSTEMDS-3443] Rename triggerremote instruction to
checkpoint_e
afa315d831 is described below
commit afa315d8315b238ad29edf5e64c3d6954afd8cfd
Author: Arnab Phani <[email protected]>
AuthorDate: Wed Nov 30 18:59:13 2022 +0100
[SYSTEMDS-3443] Rename triggerremote instruction to checkpoint_e
This patch extends the checkpoint instruction with an asynchronous
checkpoint_e (eager checkpoint) version. This operator eagerly triggers
a chain of Spark operations and persists the distributed results.
Closes #1744
---
.../apache/sysds/conf/ConfigurationManager.java | 4 +
src/main/java/org/apache/sysds/conf/DMLConfig.java | 5 +-
.../java/org/apache/sysds/hops/OptimizerUtils.java | 1 +
.../java/org/apache/sysds/lops/Checkpoint.java | 17 ++++-
.../java/org/apache/sysds/lops/compile/Dag.java | 86 ----------------------
.../lops/compile/linearization/ILinearize.java | 49 ++++++++++++
.../runtime/instructions/CPInstructionParser.java | 4 -
.../runtime/instructions/SPInstructionParser.java | 3 +-
...rationsTask.java => TriggerCheckpointTask.java} | 7 +-
.../cp/TriggerRemoteOpsCPInstruction.java | 53 -------------
.../spark/CheckpointSPInstruction.java | 19 ++++-
.../apache/sysds/utils/stats/SparkStatistics.java | 16 ++--
12 files changed, 103 insertions(+), 161 deletions(-)
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index bb6172993a..12764eacf4 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -247,6 +247,10 @@ public class ConfigurationManager
return
(getDMLConfig().getBooleanValue(DMLConfig.ASYNC_SPARK_BROADCAST)
|| OptimizerUtils.ASYNC_BROADCAST_SPARK);
}
+ public static boolean isCheckpointEnabled() {
+ return
(getDMLConfig().getBooleanValue(DMLConfig.ASYNC_SPARK_CHECKPOINT)
+ || OptimizerUtils.ASYNC_CHECKPOINT_SPARK);
+ }
public static ILinearize.DagLinearization getLinearizationOrder() {
if (OptimizerUtils.MAX_PARALLELIZE_ORDER)
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index 2225d89d89..dad670efd4 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -130,6 +130,7 @@ public class DMLConfig
/** Asynchronous triggering of Spark OPs and operator placement **/
public static final String ASYNC_SPARK_PREFETCH =
"sysds.async.prefetch"; // boolean: enable asynchronous prefetching spark
intermediates
public static final String ASYNC_SPARK_BROADCAST =
"sysds.async.broadcast"; // boolean: enable asynchronous broadcasting CP
intermediates
+ public static final String ASYNC_SPARK_CHECKPOINT =
"sysds.async.checkpoint"; // boolean: enable asynchronous persisting of Spark
intermediates
//internal config
public static final String DEFAULT_SHARED_DIR_PERMISSION = "777"; //for
local fs and DFS
@@ -202,6 +203,7 @@ public class DMLConfig
_defaultVals.put(PRIVACY_CONSTRAINT_MOCK, null);
_defaultVals.put(ASYNC_SPARK_PREFETCH, "false" );
_defaultVals.put(ASYNC_SPARK_BROADCAST, "false" );
+ _defaultVals.put(ASYNC_SPARK_CHECKPOINT, "false" );
}
public DMLConfig() {
@@ -454,7 +456,8 @@ public class DMLConfig
PRINT_GPU_MEMORY_INFO, AVAILABLE_GPUS, SYNCHRONIZE_GPU,
EAGER_CUDA_FREE, FLOATING_POINT_PRECISION,
GPU_EVICTION_POLICY, LOCAL_SPARK_NUM_THREADS,
EVICTION_SHADOW_BUFFERSIZE, GPU_MEMORY_ALLOCATOR,
GPU_MEMORY_UTILIZATION_FACTOR,
USE_SSL_FEDERATED_COMMUNICATION, DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT,
- FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY,
ASYNC_SPARK_PREFETCH, ASYNC_SPARK_BROADCAST
+ FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY,
ASYNC_SPARK_PREFETCH, ASYNC_SPARK_BROADCAST,
+ ASYNC_SPARK_CHECKPOINT
};
StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index d2e9670362..580ddccf20 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -284,6 +284,7 @@ public class OptimizerUtils
*/
public static boolean ASYNC_PREFETCH_SPARK = false;
public static boolean ASYNC_BROADCAST_SPARK = false;
+ public static boolean ASYNC_CHECKPOINT_SPARK = false;
/**
* Heuristic-based instruction ordering to maximize inter-operator
parallelism.
diff --git a/src/main/java/org/apache/sysds/lops/Checkpoint.java
b/src/main/java/org/apache/sysds/lops/Checkpoint.java
index a15078bc70..6f56d287c0 100644
--- a/src/main/java/org/apache/sysds/lops/Checkpoint.java
+++ b/src/main/java/org/apache/sysds/lops/Checkpoint.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import java.util.Arrays;
/**
* Lop for checkpoint operations. For example, on Spark, the semantic of a
checkpoint
@@ -38,13 +39,15 @@ import
org.apache.sysds.runtime.instructions.InstructionUtils;
*/
public class Checkpoint extends Lop
{
- public static final String OPCODE = "chkpoint";
-
+ public static final String DEFAULT_CP_OPCODE = "chkpoint";
+ public static final String ASYNC_CP_OPCODE = "chkpoint_e";
+
public static final StorageLevel DEFAULT_STORAGE_LEVEL =
StorageLevel.MEMORY_AND_DISK();
public static final StorageLevel SER_STORAGE_LEVEL =
StorageLevel.MEMORY_AND_DISK_SER();
public static final boolean CHECKPOINT_SPARSE_CSR = true;
private StorageLevel _storageLevel;
+ private boolean _async = false;
/**
@@ -55,16 +58,22 @@ public class Checkpoint extends Lop
* @param dt data type
* @param vt value type
* @param level storage level
+ * @param isAsync true if eager and asynchronous checkpoint
*/
- public Checkpoint(Lop input, DataType dt, ValueType vt, String level) {
+ public Checkpoint(Lop input, DataType dt, ValueType vt, String level,
boolean isAsync) {
super(Lop.Type.Checkpoint, dt, vt);
addInput(input);
input.addOutput(this);
_storageLevel = StorageLevel.fromString(level);
+ _async = isAsync;
lps.setProperties(inputs, ExecType.SPARK);
}
+ public Checkpoint(Lop input, DataType dt, ValueType vt, String level) {
+ this(input, dt, vt, level, false);
+ }
+
public StorageLevel getStorageLevel()
{
return _storageLevel;
@@ -89,7 +98,7 @@ public class Checkpoint extends Lop
return InstructionUtils.concatOperands(
getExecType().name(),
- OPCODE,
+ _async ? ASYNC_CP_OPCODE : DEFAULT_CP_OPCODE,
getInputs().get(0).prepInputOperand(input1),
prepOutputOperand(output),
getStorageLevelString(_storageLevel));
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 2efbea8221..77325fb297 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -237,56 +237,6 @@ public class Dag<N extends Lop>
}
}
- private static List<Lop> addPrefetchLop(List<Lop> nodes) {
- List<Lop> nodesWithPrefetch = new ArrayList<>();
-
- //Find the Spark nodes with all CP outputs
- for (Lop l : nodes) {
- nodesWithPrefetch.add(l);
- if (isPrefetchNeeded(l)) {
- //TODO: No prefetch if the parent is placed
right after the spark OP
- //or push the parent further to increase
parallelism
- List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
- //Construct a Prefetch lop that takes this
Spark node as a input
- UnaryCP prefetch = new UnaryCP(l,
OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
- for (Lop outCP : oldOuts) {
- //Rewire l -> outCP to l -> Prefetch ->
outCP
- prefetch.addOutput(outCP);
- outCP.replaceInput(l, prefetch);
- l.removeOutput(outCP);
- //FIXME: Rewire _inputParams when
needed (e.g. GroupedAggregate)
- }
- //Place it immediately after the Spark lop in
the node list
- nodesWithPrefetch.add(prefetch);
- }
- }
- return nodesWithPrefetch;
- }
-
- private static List<Lop> addBroadcastLop(List<Lop> nodes) {
- List<Lop> nodesWithBroadcast = new ArrayList<>();
-
- for (Lop l : nodes) {
- nodesWithBroadcast.add(l);
- if (isBroadcastNeeded(l)) {
- List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
- //Construct a Broadcast lop that takes this
Spark node as an input
- UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST,
l.getDataType(), l.getValueType(), ExecType.CP);
- //FIXME: Wire Broadcast only with the necessary
outputs
- for (Lop outCP : oldOuts) {
- //Rewire l -> outCP to l -> Broadcast
-> outCP
- bc.addOutput(outCP);
- outCP.replaceInput(l, bc);
- l.removeOutput(outCP);
- //FIXME: Rewire _inputParams when
needed (e.g. GroupedAggregate)
- }
- //Place it immediately after the Spark lop in
the node list
- nodesWithBroadcast.add(bc);
- }
- }
- return nodesWithBroadcast;
- }
-
private ArrayList<Instruction> doPlainInstructionGen(StatementBlock sb,
List<Lop> nodes)
{
//prepare basic instruction sets
@@ -319,42 +269,6 @@ public class Dag<N extends Lop>
&&
dnode.getOutputParameters().getLabel().equals(input.getOutputParameters().getLabel());
}
- private static boolean isPrefetchNeeded(Lop lop) {
- // Run Prefetch for a Spark instruction if the instruction is a
Transformation
- // and the output is consumed by only CP instructions.
- boolean transformOP = lop.getExecType() == ExecType.SPARK &&
lop.getAggType() != SparkAggType.SINGLE_BLOCK
- // Always Action operations
- && !(lop.getDataType() == DataType.SCALAR)
- && !(lop instanceof MapMultChain) && !(lop
instanceof PickByCount)
- && !(lop instanceof MMZip) && !(lop instanceof
CentralMoment)
- && !(lop instanceof CoVariance)
- // Not qualified for prefetching
- && !(lop instanceof Checkpoint) && !(lop
instanceof ReBlock)
- && !(lop instanceof CSVReBlock)
- // Cannot filter Transformation cases from
Actions (FIXME)
- && !(lop instanceof MMTSJ) && !(lop instanceof
UAggOuterChain)
- && !(lop instanceof ParameterizedBuiltin) &&
!(lop instanceof SpoofFused);
-
- //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
- boolean hasParameterizedOut = lop.getOutputs().stream()
- .anyMatch(out -> ((out instanceof
ParameterizedBuiltin)
- || (out instanceof GroupedAggregate)
- || (out instanceof GroupedAggregateM)));
- //TODO: support non-matrix outputs
- return transformOP && !hasParameterizedOut
- && lop.isAllOutputsCP() && lop.getDataType() ==
DataType.MATRIX;
- }
-
- private static boolean isBroadcastNeeded(Lop lop) {
- // Asynchronously broadcast a matrix if that is produced by a
CP instruction,
- // and at least one Spark parent needs to broadcast this
intermediate (eg. mapmm)
- boolean isBc = lop.getOutputs().stream()
- .anyMatch(out -> (out.getBroadcastInput() ==
lop));
- //TODO: Early broadcast objects that are bigger than a single
block
- //return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
- return isBc && lop.getDataType() == DataType.MATRIX;
- }
-
private static List<Instruction>
deleteUpdatedTransientReadVariables(StatementBlock sb, List<Lop> nodeV) {
List<Instruction> insts = new ArrayList<>();
if ( sb == null ) //return modifiable list
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index d867a91f4a..e78530b33b 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -319,6 +319,31 @@ public interface ILinearize {
return nodesWithBroadcast;
}
+ private static List<Lop> addAsyncEagerCheckpointLop(List<Lop> nodes) {
+ List<Lop> nodesWithCheckpoint = new ArrayList<>();
+ // Find the Spark action nodes
+ for (Lop l : nodes) {
+ if (isCheckpointNeeded(l)) {
+ List<Lop> oldInputs = new
ArrayList<>(l.getInputs());
+ // Place a Checkpoint node just below this node
(Spark action)
+ for (Lop in : oldInputs) {
+ if (in.getExecType() != ExecType.SPARK)
+ continue;
+ // Rewire in -> l to in -> Checkpoint
-> l
+ //UnaryCP checkpoint = new UnaryCP(in,
OpOp1.TRIGREMOTE, in.getDataType(), in.getValueType(), ExecType.CP);
+ Lop checkpoint = new Checkpoint(in,
in.getDataType(), in.getValueType(),
+
Checkpoint.getDefaultStorageLevelString(), true);
+ checkpoint.addOutput(l);
+ l.replaceInput(in, checkpoint);
+ in.removeOutput(l);
+ nodesWithCheckpoint.add(checkpoint);
+ }
+ }
+ nodesWithCheckpoint.add(l);
+ }
+ return nodesWithCheckpoint;
+ }
+
private static boolean isPrefetchNeeded(Lop lop) {
// Run Prefetch for a Spark instruction if the instruction is a
Transformation
// and the output is consumed by only CP instructions.
@@ -354,4 +379,28 @@ public interface ILinearize {
//return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
return isBc && lop.getDataType() == DataType.MATRIX;
}
+
+ private static boolean isCheckpointNeeded(Lop lop) {
+ // Place checkpoint_e just before a Spark action (FIXME)
+ boolean actionOP = lop.getExecType() == ExecType.SPARK
+ && ((lop.getAggType() ==
SparkAggType.SINGLE_BLOCK)
+ // Always Action operations
+ || (lop.getDataType() == DataType.SCALAR)
+ || (lop instanceof MapMultChain) || (lop
instanceof PickByCount)
+ || (lop instanceof MMZip) || (lop instanceof
CentralMoment)
+ || (lop instanceof CoVariance) || (lop
instanceof MMTSJ))
+ // Not qualified for Checkpoint
+ && !(lop instanceof Checkpoint) && !(lop
instanceof ReBlock)
+ && !(lop instanceof CSVReBlock)
+ // Cannot filter Transformation cases from
Actions (FIXME)
+ && !(lop instanceof UAggOuterChain)
+ && !(lop instanceof ParameterizedBuiltin) &&
!(lop instanceof SpoofFused);
+
+ //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
+ boolean hasParameterizedOut = lop.getOutputs().stream()
+ .anyMatch(out -> ((out instanceof
ParameterizedBuiltin)
+ || (out instanceof GroupedAggregate)
+ || (out instanceof GroupedAggregateM)));
+ return actionOP && !hasParameterizedOut;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index b83a78674b..78dec00c24 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -67,7 +67,6 @@ import
org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
import org.apache.sysds.runtime.instructions.cp.SqlCPInstruction;
import org.apache.sysds.runtime.instructions.cp.StringInitCPInstruction;
import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.TriggerRemoteOpsCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -482,9 +481,6 @@ public class CPInstructionParser extends InstructionParser
case Broadcast:
return
BroadcastCPInstruction.parseInstruction(str);
- case TrigRemote:
- return
TriggerRemoteOpsCPInstruction.parseInstruction(str);
-
default:
throw new DMLRuntimeException("Invalid CP
Instruction Type: " + cptype );
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 0ff745de57..773153d6d4 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -240,7 +240,8 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put("libsvmrblk",
SPType.LIBSVMReblock);
// Spark-specific instructions
- String2SPInstructionType.put( Checkpoint.OPCODE,
SPType.Checkpoint);
+ String2SPInstructionType.put( Checkpoint.DEFAULT_CP_OPCODE,
SPType.Checkpoint);
+ String2SPInstructionType.put( Checkpoint.ASYNC_CP_OPCODE,
SPType.Checkpoint);
String2SPInstructionType.put( Compression.OPCODE,
SPType.Compression);
String2SPInstructionType.put( DeCompression.OPCODE,
SPType.DeCompression);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerCheckpointTask.java
similarity index 89%
rename from
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
rename to
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerCheckpointTask.java
index 63e6f56fd5..657460f62e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerCheckpointTask.java
@@ -25,10 +25,10 @@ import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.utils.stats.SparkStatistics;
-public class TriggerRemoteOperationsTask implements Runnable {
+public class TriggerCheckpointTask implements Runnable {
MatrixObject _remoteOperationsRoot;
- public TriggerRemoteOperationsTask(MatrixObject mo) {
+ public TriggerCheckpointTask(MatrixObject mo) {
_remoteOperationsRoot = mo;
}
@@ -36,6 +36,7 @@ public class TriggerRemoteOperationsTask implements Runnable {
public void run() {
boolean triggered = false;
synchronized (_remoteOperationsRoot) {
+ // FIXME: Handle double execution
if (_remoteOperationsRoot.isPendingRDDOps()) {
JavaPairRDD<?, ?> rdd =
_remoteOperationsRoot.getRDDHandle().getRDD();
rdd.persist(Checkpoint.DEFAULT_STORAGE_LEVEL).count();
@@ -45,6 +46,6 @@ public class TriggerRemoteOperationsTask implements Runnable {
}
if (DMLScript.STATISTICS && triggered)
- SparkStatistics.incAsyncTriggerRemoteCount(1);
+ SparkStatistics.incAsyncTriggerCheckpointCount(1);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
deleted file mode 100644
index 98d7f440c4..0000000000
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.sysds.runtime.instructions.cp;
-
-import java.util.concurrent.Executors;
-
-import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.matrix.operators.Operator;
-import org.apache.sysds.runtime.util.CommonThreadPool;
-
-public class TriggerRemoteOpsCPInstruction extends UnaryCPInstruction {
- private TriggerRemoteOpsCPInstruction(Operator op, CPOperand in,
CPOperand out, String opcode, String istr) {
- super(CPType.TrigRemote, op, in, out, opcode, istr);
- }
-
- public static TriggerRemoteOpsCPInstruction parseInstruction (String
str) {
- InstructionUtils.checkNumFields(str, 2);
- String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
- String opcode = parts[0];
- CPOperand in = new CPOperand(parts[1]);
- CPOperand out = new CPOperand(parts[2]);
- return new TriggerRemoteOpsCPInstruction(null, in, out, opcode,
str);
- }
-
- @Override
- public void processInstruction(ExecutionContext ec) {
- // TODO: Operator placement.
- // Note for testing: write a method in the Dag class to place
this operator
- // after Spark MMRJ. Then execute
PrefetchRDDTest.testAsyncSparkOPs3.
- ec.setVariable(output.getName(), ec.getMatrixObject(input1));
-
- if (CommonThreadPool.triggerRemoteOPsPool == null)
- CommonThreadPool.triggerRemoteOPsPool =
Executors.newCachedThreadPool();
- CommonThreadPool.triggerRemoteOPsPool.submit(new
TriggerRemoteOperationsTask(ec.getMatrixObject(output)));
- }
-}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
index f305630891..56455474b7 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
@@ -35,6 +35,7 @@ import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.TriggerCheckpointTask;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import
org.apache.sysds.runtime.instructions.spark.functions.CopyFrameBlockFunction;
import
org.apache.sysds.runtime.instructions.spark.functions.CreateSparseBlockFunction;
@@ -43,9 +44,12 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
+import java.util.concurrent.Executors;
+
public class CheckpointSPInstruction extends UnarySPInstruction {
// default storage level
private StorageLevel _level = null;
@@ -71,7 +75,20 @@ public class CheckpointSPInstruction extends
UnarySPInstruction {
@SuppressWarnings("unchecked")
public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext)ec;
-
+
+ // Asynchronously trigger count() and persist this RDD
+ // TODO: Synchronize. Avoid double execution
+ if (getOpcode().equals("chkpoint_e")) { //eager checkpoint
+ // Inplace replace output matrix object with the input
matrix object
+ // We will never use the output of the Spark count call
+ ec.setVariable(output.getName(),
ec.getCacheableData(input1));
+
+ if (CommonThreadPool.triggerRemoteOPsPool == null)
+ CommonThreadPool.triggerRemoteOPsPool =
Executors.newCachedThreadPool();
+ CommonThreadPool.triggerRemoteOPsPool.submit(new
TriggerCheckpointTask(ec.getMatrixObject(output)));
+ return;
+ }
+
// Step 1: early abort on non-existing or in-memory (cached)
inputs
// -------
// (checkpoints are generated for all read only variables in
loops; due to unbounded scoping and
diff --git a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
index 331634d1bd..0d1b5b05b9 100644
--- a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
@@ -33,7 +33,7 @@ public class SparkStatistics {
private static final LongAdder broadcastCount = new LongAdder();
private static final LongAdder asyncPrefetchCount = new LongAdder();
private static final LongAdder asyncBroadcastCount = new LongAdder();
- private static final LongAdder asyncTriggerRemoteCount = new
LongAdder();
+ private static final LongAdder asyncTriggerCheckpointCount = new
LongAdder();
public static boolean createdSparkContext() {
return ctxCreateTime > 0;
@@ -76,8 +76,8 @@ public class SparkStatistics {
asyncBroadcastCount.add(c);
}
- public static void incAsyncTriggerRemoteCount(long c) {
- asyncTriggerRemoteCount.add(c);
+ public static void incAsyncTriggerCheckpointCount(long c) {
+ asyncTriggerCheckpointCount.add(c);
}
public static long getSparkCollectCount() {
@@ -92,8 +92,8 @@ public class SparkStatistics {
return asyncBroadcastCount.longValue();
}
- public static long getAsyncTriggerRemoteCount() {
- return asyncTriggerRemoteCount.longValue();
+ public static long getasyncTriggerCheckpointCount() {
+ return asyncTriggerCheckpointCount.longValue();
}
public static void reset() {
@@ -106,7 +106,7 @@ public class SparkStatistics {
collectCount.reset();
asyncPrefetchCount.reset();
asyncBroadcastCount.reset();
- asyncTriggerRemoteCount.reset();
+ asyncTriggerCheckpointCount.reset();
}
public static String displayStatistics() {
@@ -122,8 +122,8 @@ public class SparkStatistics {
parallelizeTime.longValue()*1e-9,
broadcastTime.longValue()*1e-9,
collectTime.longValue()*1e-9));
- sb.append("Spark async. count (pf,bc,tr): \t" +
- String.format("%d/%d/%d.\n",
getAsyncPrefetchCount(), getAsyncBroadcastCount(),
getAsyncTriggerRemoteCount()));
+ sb.append("Spark async. count (pf,bc,cp): \t" +
+ String.format("%d/%d/%d.\n",
getAsyncPrefetchCount(), getAsyncBroadcastCount(),
getasyncTriggerCheckpointCount()));
return sb.toString();
}
}