This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c674b77e67 [SYSTEMDS-3479] Persist and reuse of Spark RDDs
c674b77e67 is described below
commit c674b77e673d6698af6be2f72ea5c44d31210b29
Author: Arnab Phani <[email protected]>
AuthorDate: Fri Dec 16 23:30:18 2022 +0100
[SYSTEMDS-3479] Persist and reuse of Spark RDDs
This patch extends the lineage cache to store the RDD objects
which are checkpointed. This addition allows the compiler
to place a chkpoint after a potentially redundant operator.
During runtime, we then persist the RDD, save the RDD in the
lineage cache, and reuse it if the instruction repeats. It is a
bit tricky to cache an RDD for a lineage trace of a previous
instruction. A better way would be to be able to mark any
instruction to persist the result RDD and skip the chkpoint
instruction.
Hyperparameter tuning of LmDS with 2.5k columns improves by
4x by caching the cpmm results in the executors.
Closes #1751
---
.../context/SparkExecutionContext.java | 5 ++
.../spark/CheckpointSPInstruction.java | 11 ++++-
.../apache/sysds/runtime/lineage/LineageCache.java | 53 +++++++++++++++++++++-
.../sysds/runtime/lineage/LineageCacheConfig.java | 18 ++++++--
.../sysds/runtime/lineage/LineageCacheEntry.java | 35 ++++++++++++--
.../multitenant/FederatedReuseSlicesTest.java | 3 +-
.../FederatedSerializationReuseTest.java | 3 +-
7 files changed, 117 insertions(+), 11 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 4118eee170..48778cb4d4 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -889,6 +889,11 @@ public class SparkExecutionContext extends ExecutionContext
obj.setRDDHandle( rddhandle );
}
+ public void setRDDHandleForVariable(String varname, RDDObject
rddhandle) {
+ CacheableData<?> obj = getCacheableData(varname);
+ obj.setRDDHandle(rddhandle);
+ }
+
public static JavaPairRDD<MatrixIndexes,MatrixBlock>
toMatrixJavaPairRDD(JavaSparkContext sc, MatrixBlock src, int blen) {
return toMatrixJavaPairRDD(sc, src, blen, -1, true);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
index 56455474b7..16d8c9d228 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
@@ -40,6 +40,7 @@ import
org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import
org.apache.sysds.runtime.instructions.spark.functions.CopyFrameBlockFunction;
import
org.apache.sysds.runtime.instructions.spark.functions.CreateSparseBlockFunction;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -77,10 +78,10 @@ public class CheckpointSPInstruction extends
UnarySPInstruction {
SparkExecutionContext sec = (SparkExecutionContext)ec;
// Asynchronously trigger count() and persist this RDD
- // TODO: Synchronize. Avoid double execution
if (getOpcode().equals("chkpoint_e")) { //eager checkpoint
// Inplace replace output matrix object with the input
matrix object
// We will never use the output of the Spark count call
+ // TODO: Synchronize. Avoid double execution
ec.setVariable(output.getName(),
ec.getCacheableData(input1));
if (CommonThreadPool.triggerRemoteOPsPool == null)
@@ -100,6 +101,14 @@ public class CheckpointSPInstruction extends
UnarySPInstruction {
sec.setVariable( output.getName(), new
BooleanObject(false));
return;
}
+
+ if (!LineageCacheConfig.ReuseCacheType.isNone() &&
sec.getCacheableData(input1).getRDDHandle() != null
+ &&
sec.getCacheableData(input1).getRDDHandle().isCheckpointRDD()) {
+ // Do nothing if the RDD is already checkpointed
+ sec.setVariable(output.getName(),
sec.getCacheableData(input1.getName()));
+ Statistics.decrementNoOfExecutedSPInst();
+ return;
+ }
//-------
//(for csv input files with unknown dimensions, we might have
generated a checkpoint after
//csvreblock although not necessary because the csvreblock was
subject to in-memory reblock)
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 5e35181a04..d4eb4b8f92 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -31,9 +31,11 @@ import org.apache.sysds.lops.MMTSJ.MMTSJType;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.MatrixObjectFuture;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
@@ -51,7 +53,9 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
+import org.apache.sysds.runtime.instructions.spark.CheckpointSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -183,6 +187,14 @@ public class LineageCache
else
ec.setScalarOutput(outName, so);
}
+ else if (e.isRDDPersist()) {
+ //Reuse the RDD which is also
persisted in Spark
+ RDDObject rdd =
e.getRDDObject();
+ if (rdd == null &&
e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
+ return false;
+ else
+
((SparkExecutionContext)ec).setRDDHandleForVariable(outName, rdd);
+ }
else { //TODO handle locks on gpu
objects
//shallow copy the cached
GPUObj to the output MatrixObject
ec.getMatrixObject(outName).setGPUObject(ec.getGPUContext(0),
@@ -520,7 +532,9 @@ public class LineageCache
//if (!isMarkedForCaching(inst, ec)) return;
List<Pair<LineageItem, Data>> liData = null;
GPUObject liGpuObj = null;
+ RDDObject rddObj = null;
LineageItem instLI = ((LineageTraceable)
inst).getLineageItem(ec).getValue();
+ LineageItem instInputLI = null;
if (inst instanceof MultiReturnBuiltinCPInstruction) {
liData = new ArrayList<>();
MultiReturnBuiltinCPInstruction mrInst =
(MultiReturnBuiltinCPInstruction)inst;
@@ -542,6 +556,15 @@ public class LineageCache
if (liGpuObj == null)
liData = Arrays.asList(Pair.of(instLI,
ec.getVariable(((GPUInstruction)inst)._output)));
}
+ else if (inst instanceof CheckpointSPInstruction) {
+ // Get the lineage of the instruction being
checkpointed
+ instInputLI =
ec.getLineageItem(((ComputationSPInstruction)inst).input1);
+ // Get the RDD handle of the persisted RDD
+ CacheableData<?> cd =
ec.getCacheableData(((ComputationSPInstruction)inst).output.getName());
+ rddObj = ((CacheableData<?>) cd).getRDDHandle();
+ // Remove the lineage item of the chkpoint
instruction
+ removePlaceholder(instLI);
+ }
else
if (inst instanceof ComputationCPInstruction)
liData = Arrays.asList(Pair.of(instLI,
ec.getVariable(((ComputationCPInstruction) inst).output)));
@@ -550,10 +573,12 @@ public class LineageCache
else if (inst instanceof
ComputationSPInstruction)
liData = Arrays.asList(Pair.of(instLI,
ec.getVariable(((ComputationSPInstruction) inst).output)));
- if (liGpuObj == null)
+ if (liGpuObj == null && rddObj == null)
putValueCPU(inst, liData, computetime);
- else
+ if (liGpuObj != null)
putValueGPU(liGpuObj, instLI, computetime);
+ if (rddObj != null)
+ putValueRDD(rddObj, instInputLI, computetime);
}
}
@@ -582,6 +607,13 @@ public class LineageCache
continue;
}
+ if (LineageCacheConfig.isToPersist(inst) &&
LineageCacheConfig.getCompAssRW()) {
+ // The immediately following
instruction must be a checkpoint, which will
+ // fill the rdd in this cache entry.
+ // TODO: Instead check if this
instruction is marked for checkpointing
+ continue;
+ }
+
if (data instanceof MatrixObject &&
((MatrixObject) data).hasRDDHandle()) {
// Avoid triggering pre-matured Spark
instruction chains
removePlaceholder(item);
@@ -640,6 +672,23 @@ public class LineageCache
}
}
+ private static void putValueRDD(RDDObject rdd, LineageItem instLI, long
computetime) {
+ synchronized( _cache ) {
+ // Not available in the cache indicates this RDD is not
marked for caching
+ if (!probe(instLI))
+ return;
+
+ LineageCacheEntry centry = _cache.get(instLI);
+ if (centry.isRDDPersist() &&
centry.getRDDObject().isCheckpointRDD())
+ // Do nothing if the cached RDD is already
checkpointed
+ return;
+
+ centry.setRDDValue(rdd, computetime);
+ // Maintain order for eviction
+ LineageCacheEviction.addEntry(centry);
+ }
+ }
+
// This method is called from inside the asynchronous operators and
directly put the output of
// an asynchronous instruction into the lineage cache. As the
consumers, a different operator,
// materializes the intermediate, we skip the placeholder placing logic.
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index e7761b1333..52b8399fcc 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -54,11 +54,18 @@ public class LineageCacheConfig
"^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand",
"replace",
"^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax",
"qsort",
"qpick", "transformapply", "uarmax", "n+", "-*", "castdtm",
"lowertri",
- "mapmm", "cpmm", "prefetch"
+ "mapmm", "cpmm", "rmm", "prefetch", "chkpoint"
//TODO: Reuse everything.
};
+
+ private static final String[] OPCODES_CP = new String[] {
+ "cpmm", "rmm"
+ //TODO: Instead mark an instruction to be checkpointed
+ };
+
private static String[] REUSE_OPCODES = new String[] {};
-
+ private static String[] OPCODES_CHECKPOINTS = new String[] {};
+
public enum ReuseCacheType {
REUSE_FULL,
REUSE_PARTIAL,
@@ -189,9 +196,10 @@ public class LineageCacheConfig
static {
//setup static configuration parameters
REUSE_OPCODES = OPCODES;
+ OPCODES_CHECKPOINTS = OPCODES_CP;
//setSpill(true);
setCachePolicy(LineageCachePolicy.COSTNSIZE);
- setCompAssRW(true);
+ setCompAssRW(false);
}
public static void setReusableOpcodes(String... ops) {
@@ -201,6 +209,10 @@ public class LineageCacheConfig
public static String[] getReusableOpcodes() {
return REUSE_OPCODES;
}
+
+ public static boolean isToPersist(Instruction inst) {
+ return ArrayUtils.contains(OPCODES_CHECKPOINTS,
inst.getOpcode());
+ }
public static void resetReusableOpcodes() {
REUSE_OPCODES = OPCODES;
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index b8e30cb4c3..0042674e56 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
+import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -42,6 +43,7 @@ public class LineageCacheEntry {
private String _outfile = null;
protected double score;
protected GPUObject _gpuObject;
+ protected RDDObject _rddObject;
public LineageCacheEntry(LineageItem key, DataType dt, MatrixBlock
Mval, ScalarObject Sval, long computetime) {
_key = key;
@@ -90,6 +92,21 @@ public class LineageCacheEntry {
}
}
+ public synchronized RDDObject getRDDObject() {
+ try {
+ //wait until other thread completes operation
+ //in order to avoid redundant computation
+ while(_status == LineageCacheStatus.EMPTY) {
+ wait();
+ }
+ //comes here if data is placed or the entry is removed
by the running thread
+ return _rddObject;
+ }
+ catch( InterruptedException ex ) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
public synchronized byte[] getSerializedBytes() {
try {
// wait until other thread completes operation
@@ -129,15 +146,19 @@ public class LineageCacheEntry {
}
public boolean isNullVal() {
- return(_MBval == null && _SOval == null && _gpuObject == null
&& _serialBytes == null);
+ return(_MBval == null && _SOval == null && _gpuObject == null
&& _serialBytes == null && _rddObject == null);
}
public boolean isMatrixValue() {
- return _dt.isMatrix();
+ return _dt.isMatrix() && _rddObject == null;
}
public boolean isScalarValue() {
- return _dt.isScalar();
+ return _dt.isScalar() && _rddObject == null;
+ }
+
+ public boolean isRDDPersist() {
+ return _rddObject != null;
}
public boolean isSerializedBytes() {
@@ -175,6 +196,14 @@ public class LineageCacheEntry {
notifyAll();
}
+ public synchronized void setRDDValue(RDDObject rdd, long computetime) {
+ _rddObject = rdd;
+ _computeTime = computetime;
+ _status = isNullVal() ? LineageCacheStatus.EMPTY :
LineageCacheStatus.CACHED;
+ //resume all threads waiting for val
+ notifyAll();
+ }
+
public synchronized void setValue(byte[] serialBytes, long computetime)
{
_serialBytes = serialBytes;
_computeTime = computetime;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
index fc045933c6..50d7f662a2 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
@@ -158,7 +158,8 @@ public class FederatedReuseSlicesTest extends
MultiTenantTestBase {
// start the coordinator processes
String scriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-config", CONFIG_DIR +
"SystemDS-MultiTenant-config.xml",
- "-lineage", "reuse", "-stats", "100", "-fedStats",
"100", "-nvargs",
+ //"-lineage", "reuse", "-stats", "100", "-fedStats",
"100", "-nvargs",
+ "-stats", "100", "-fedStats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(workerPorts[0],
input("X1")),
"in_X2=" + TestUtils.federatedAddress(workerPorts[1],
input("X2")),
"in_X3=" + TestUtils.federatedAddress(workerPorts[2],
input("X3")),
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
index 3b66a5c393..05f02846c0 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
@@ -159,7 +159,8 @@ public class FederatedSerializationReuseTest extends
MultiTenantTestBase {
// start the coordinator processes
String scriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-config", CONFIG_DIR +
"SystemDS-MultiTenant-config.xml",
- "-lineage", "reuse", "-stats", "100", "-fedStats",
"100", "-nvargs",
+ //"-lineage", "reuse", "-stats", "100", "-fedStats",
"100", "-nvargs",
+ "-stats", "100", "-fedStats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(workerPorts[0],
input("X1")),
"in_X2=" + TestUtils.federatedAddress(workerPorts[1],
input("X2")),
"in_X3=" + TestUtils.federatedAddress(workerPorts[2],
input("X3")),