This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c674b77e67 [SYSTEMDS-3479] Persist and reuse of Spark RDDs
c674b77e67 is described below

commit c674b77e673d6698af6be2f72ea5c44d31210b29
Author: Arnab Phani <[email protected]>
AuthorDate: Fri Dec 16 23:30:18 2022 +0100

    [SYSTEMDS-3479] Persist and reuse of Spark RDDs
    
    This patch extends lineage cache to store the RDD objects
    which are checkpointed. This addition allows the compiler
    to place a chkpoint after a potentially redundant operator.
    During runtime, we then persist the RDD, save the RDD in the
    lineage cache, and reuse if the instruction repeats. It is a
    bit tricky to cache an RDD for a lineage trace of a previous
    instruction. A better way would be to be able to mark any
    instruction to persist the result RDD and skip the chkpoint
    instruction.
    Hyperparameter tuning of LmDS with 2.5k columns improves by
    4x by caching the cpmm results in the executors.
    
    Closes #1751
---
 .../context/SparkExecutionContext.java             |  5 ++
 .../spark/CheckpointSPInstruction.java             | 11 ++++-
 .../apache/sysds/runtime/lineage/LineageCache.java | 53 +++++++++++++++++++++-
 .../sysds/runtime/lineage/LineageCacheConfig.java  | 18 ++++++--
 .../sysds/runtime/lineage/LineageCacheEntry.java   | 35 ++++++++++++--
 .../multitenant/FederatedReuseSlicesTest.java      |  3 +-
 .../FederatedSerializationReuseTest.java           |  3 +-
 7 files changed, 117 insertions(+), 11 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 4118eee170..48778cb4d4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -889,6 +889,11 @@ public class SparkExecutionContext extends ExecutionContext
                obj.setRDDHandle( rddhandle );
        }
 
+       public void setRDDHandleForVariable(String varname, RDDObject 
rddhandle) {
+               CacheableData<?> obj = getCacheableData(varname);
+               obj.setRDDHandle(rddhandle);
+       }
+
        public static JavaPairRDD<MatrixIndexes,MatrixBlock> 
toMatrixJavaPairRDD(JavaSparkContext sc, MatrixBlock src, int blen) {
                return toMatrixJavaPairRDD(sc, src, blen, -1, true);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
index 56455474b7..16d8c9d228 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
@@ -40,6 +40,7 @@ import 
org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import 
org.apache.sysds.runtime.instructions.spark.functions.CopyFrameBlockFunction;
 import 
org.apache.sysds.runtime.instructions.spark.functions.CreateSparseBlockFunction;
 import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -77,10 +78,10 @@ public class CheckpointSPInstruction extends 
UnarySPInstruction {
                SparkExecutionContext sec = (SparkExecutionContext)ec;
 
                // Asynchronously trigger count() and persist this RDD
-               // TODO: Synchronize. Avoid double execution
                if (getOpcode().equals("chkpoint_e")) {  //eager checkpoint
                        // Inplace replace output matrix object with the input 
matrix object
                        // We will never use the output of the Spark count call
+                       // TODO: Synchronize. Avoid double execution
                        ec.setVariable(output.getName(), 
ec.getCacheableData(input1));
 
                        if (CommonThreadPool.triggerRemoteOPsPool == null)
@@ -100,6 +101,14 @@ public class CheckpointSPInstruction extends 
UnarySPInstruction {
                        sec.setVariable( output.getName(), new 
BooleanObject(false));
                        return;
                }
+
+               if (!LineageCacheConfig.ReuseCacheType.isNone() && 
sec.getCacheableData(input1).getRDDHandle() != null
+                       && 
sec.getCacheableData(input1).getRDDHandle().isCheckpointRDD()) {
+                       // Do nothing if the RDD is already checkpointed
+                       sec.setVariable(output.getName(), 
sec.getCacheableData(input1.getName()));
+                       Statistics.decrementNoOfExecutedSPInst();
+                       return;
+               }
                //-------
                //(for csv input files with unknown dimensions, we might have 
generated a checkpoint after
                //csvreblock although not necessary because the csvreblock was 
subject to in-memory reblock)
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 5e35181a04..d4eb4b8f92 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -31,9 +31,11 @@ import org.apache.sysds.lops.MMTSJ.MMTSJType;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.MatrixObjectFuture;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
@@ -51,7 +53,9 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
 import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
+import org.apache.sysds.runtime.instructions.spark.CheckpointSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -183,6 +187,14 @@ public class LineageCache
                                                else
                                                        
ec.setScalarOutput(outName, so);
                                        }
+                                       else if (e.isRDDPersist()) {
+                                               //Reuse the RDD which is also 
persisted in Spark
+                                               RDDObject rdd = 
e.getRDDObject();
+                                               if (rdd == null && 
e.getCacheStatus() == LineageCacheStatus.NOTCACHED)
+                                                       return false;
+                                               else
+                                                       
((SparkExecutionContext)ec).setRDDHandleForVariable(outName, rdd);
+                                       }
                                        else { //TODO handle locks on gpu 
objects
                                                //shallow copy the cached 
GPUObj to the output MatrixObject
                                                
ec.getMatrixObject(outName).setGPUObject(ec.getGPUContext(0), 
@@ -520,7 +532,9 @@ public class LineageCache
                        //if (!isMarkedForCaching(inst, ec)) return;
                        List<Pair<LineageItem, Data>> liData = null;
                        GPUObject liGpuObj = null;
+                       RDDObject rddObj = null;
                        LineageItem instLI = ((LineageTraceable) 
inst).getLineageItem(ec).getValue();
+                       LineageItem instInputLI = null;
                        if (inst instanceof MultiReturnBuiltinCPInstruction) {
                                liData = new ArrayList<>();
                                MultiReturnBuiltinCPInstruction mrInst = 
(MultiReturnBuiltinCPInstruction)inst;
@@ -542,6 +556,15 @@ public class LineageCache
                                if (liGpuObj == null)
                                        liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((GPUInstruction)inst)._output)));
                        }
+                       else if (inst instanceof CheckpointSPInstruction) {
+                               // Get the lineage of the instruction being 
checkpointed
+                               instInputLI = 
ec.getLineageItem(((ComputationSPInstruction)inst).input1);
+                               // Get the RDD handle of the persisted RDD
+                               CacheableData<?> cd = 
ec.getCacheableData(((ComputationSPInstruction)inst).output.getName());
+                               rddObj = ((CacheableData<?>) cd).getRDDHandle();
+                               // Remove the lineage item of the chkpoint 
instruction
+                               removePlaceholder(instLI);
+                       }
                        else
                                if (inst instanceof ComputationCPInstruction)
                                        liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((ComputationCPInstruction) inst).output)));
@@ -550,10 +573,12 @@ public class LineageCache
                                else if (inst instanceof 
ComputationSPInstruction)
                                        liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((ComputationSPInstruction) inst).output)));
 
-                       if (liGpuObj == null)
+                       if (liGpuObj == null && rddObj == null)
                                putValueCPU(inst, liData, computetime);
-                       else
+                       if (liGpuObj != null)
                                putValueGPU(liGpuObj, instLI, computetime);
+                       if (rddObj != null)
+                               putValueRDD(rddObj, instInputLI, computetime);
                }
        }
        
@@ -582,6 +607,13 @@ public class LineageCache
                                        continue;
                                }
 
+                               if (LineageCacheConfig.isToPersist(inst) && 
LineageCacheConfig.getCompAssRW()) {
+                                       // The immediately following 
instruction must be a checkpoint, which will
+                                       // fill the rdd in this cache entry.
+                                       // TODO: Instead check if this 
instruction is marked for checkpointing
+                                       continue;
+                               }
+
                                if (data instanceof MatrixObject && 
((MatrixObject) data).hasRDDHandle()) {
                                        // Avoid triggering pre-matured Spark 
instruction chains
                                        removePlaceholder(item);
@@ -640,6 +672,23 @@ public class LineageCache
                }
        }
 
+       private static void putValueRDD(RDDObject rdd, LineageItem instLI, long 
computetime) {
+               synchronized( _cache ) {
+                       // Not available in the cache indicates this RDD is not 
marked for caching
+                       if (!probe(instLI))
+                               return;
+
+                       LineageCacheEntry centry = _cache.get(instLI);
+                       if (centry.isRDDPersist() && 
centry.getRDDObject().isCheckpointRDD())
+                               // Do nothing if the cached RDD is already 
checkpointed
+                               return;
+
+                       centry.setRDDValue(rdd, computetime);
+                       // Maintain order for eviction
+                       LineageCacheEviction.addEntry(centry);
+               }
+       }
+
        // This method is called from inside the asynchronous operators and 
directly put the output of
        // an asynchronous instruction into the lineage cache. As the 
consumers, a different operator,
        // materializes the intermediate, we skip the placeholder placing logic.
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index e7761b1333..52b8399fcc 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -54,11 +54,18 @@ public class LineageCacheConfig
                "^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand", 
"replace",
                "^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax", 
"qsort", 
                "qpick", "transformapply", "uarmax", "n+", "-*", "castdtm", 
"lowertri",
-               "mapmm", "cpmm", "prefetch"
+               "mapmm", "cpmm", "rmm", "prefetch", "chkpoint"
                //TODO: Reuse everything. 
        };
+
+       private static final String[] OPCODES_CP = new String[] {
+               "cpmm", "rmm"
+               //TODO: Instead mark an instruction to be checkpointed
+       };
+
        private static String[] REUSE_OPCODES  = new String[] {};
-       
+       private static String[] OPCODES_CHECKPOINTS  = new String[] {};
+
        public enum ReuseCacheType {
                REUSE_FULL,
                REUSE_PARTIAL,
@@ -189,9 +196,10 @@ public class LineageCacheConfig
        static {
                //setup static configuration parameters
                REUSE_OPCODES = OPCODES;
+               OPCODES_CHECKPOINTS = OPCODES_CP;
                //setSpill(true); 
                setCachePolicy(LineageCachePolicy.COSTNSIZE);
-               setCompAssRW(true);
+               setCompAssRW(false);
        }
 
        public static void setReusableOpcodes(String... ops) {
@@ -201,6 +209,10 @@ public class LineageCacheConfig
        public static String[] getReusableOpcodes() {
                return REUSE_OPCODES;
        }
+
+       public static boolean isToPersist(Instruction inst) {
+               return ArrayUtils.contains(OPCODES_CHECKPOINTS, 
inst.getOpcode());
+       }
        
        public static void resetReusableOpcodes() {
                REUSE_OPCODES = OPCODES;
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index b8e30cb4c3..0042674e56 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
+import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
@@ -42,6 +43,7 @@ public class LineageCacheEntry {
        private String _outfile = null;
        protected double score;
        protected GPUObject _gpuObject;
+       protected RDDObject _rddObject;
        
        public LineageCacheEntry(LineageItem key, DataType dt, MatrixBlock 
Mval, ScalarObject Sval, long computetime) {
                _key = key;
@@ -90,6 +92,21 @@ public class LineageCacheEntry {
                }
        }
 
+       public synchronized RDDObject getRDDObject() {
+               try {
+                       //wait until other thread completes operation
+                       //in order to avoid redundant computation
+                       while(_status == LineageCacheStatus.EMPTY) {
+                               wait();
+                       }
+                       //comes here if data is placed or the entry is removed 
by the running thread
+                       return _rddObject;
+               }
+               catch( InterruptedException ex ) {
+                       throw new DMLRuntimeException(ex);
+               }
+       }
+
        public synchronized byte[] getSerializedBytes() {
                try {
                        // wait until other thread completes operation
@@ -129,15 +146,19 @@ public class LineageCacheEntry {
        }
        
        public boolean isNullVal() {
-               return(_MBval == null && _SOval == null && _gpuObject == null 
&& _serialBytes == null);
+               return(_MBval == null && _SOval == null && _gpuObject == null 
&& _serialBytes == null && _rddObject == null);
        }
        
        public boolean isMatrixValue() {
-               return _dt.isMatrix();
+               return _dt.isMatrix() && _rddObject == null;
        }
 
        public boolean isScalarValue() {
-               return _dt.isScalar();
+               return _dt.isScalar() && _rddObject == null;
+       }
+
+       public boolean isRDDPersist() {
+               return _rddObject != null;
        }
 
        public boolean isSerializedBytes() {
@@ -175,6 +196,14 @@ public class LineageCacheEntry {
                notifyAll();
        }
 
+       public synchronized void setRDDValue(RDDObject rdd, long computetime) {
+               _rddObject = rdd;
+               _computeTime = computetime;
+               _status = isNullVal() ? LineageCacheStatus.EMPTY : 
LineageCacheStatus.CACHED;
+               //resume all threads waiting for val
+               notifyAll();
+       }
+
        public synchronized void setValue(byte[] serialBytes, long computetime) 
{
                _serialBytes = serialBytes;
                _computeTime = computetime;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
index fc045933c6..50d7f662a2 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
@@ -158,7 +158,8 @@ public class FederatedReuseSlicesTest extends 
MultiTenantTestBase {
                // start the coordinator processes
                String scriptName = HOME + TEST_NAME + ".dml";
                programArgs = new String[] {"-config", CONFIG_DIR + 
"SystemDS-MultiTenant-config.xml",
-                       "-lineage", "reuse", "-stats", "100", "-fedStats", 
"100", "-nvargs",
+                       //"-lineage", "reuse", "-stats", "100", "-fedStats", 
"100", "-nvargs",
+                       "-stats", "100", "-fedStats", "100", "-nvargs",
                        "in_X1=" + TestUtils.federatedAddress(workerPorts[0], 
input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(workerPorts[1], 
input("X2")),
                        "in_X3=" + TestUtils.federatedAddress(workerPorts[2], 
input("X3")),
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
index 3b66a5c393..05f02846c0 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
@@ -159,7 +159,8 @@ public class FederatedSerializationReuseTest extends 
MultiTenantTestBase {
                // start the coordinator processes
                String scriptName = HOME + TEST_NAME + ".dml";
                programArgs = new String[] {"-config", CONFIG_DIR + 
"SystemDS-MultiTenant-config.xml",
-                       "-lineage", "reuse", "-stats", "100", "-fedStats", 
"100", "-nvargs",
+                       //"-lineage", "reuse", "-stats", "100", "-fedStats", 
"100", "-nvargs",
+                       "-stats", "100", "-fedStats", "100", "-nvargs",
                        "in_X1=" + TestUtils.federatedAddress(workerPorts[0], 
input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(workerPorts[1], 
input("X2")),
                        "in_X3=" + TestUtils.federatedAddress(workerPorts[2], 
input("X3")),

Reply via email to