This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new ede3635a6c [SYSTEMDS-3443] Asynchronously execute and persist Spark transformations
ede3635a6c is described below

commit ede3635a6c0ba9d7089044eeec5317f685e9e03b
Author: Arnab Phani <[email protected]>
AuthorDate: Mon Oct 10 21:02:52 2022 +0200

    [SYSTEMDS-3443] Asynchronously execute and persist Spark transformations
    
    This patch adds an operator to asynchronously trigger a chain of Spark
    transformations and persist the result. This is a generalization of
    the Prefetch instruction and works when the consumer is a Spark
    transformation or action. TODO: operator placement (in parallel with a
    CP instruction chain).
    
    Closes #1704
---
 src/main/java/org/apache/sysds/common/Types.java   |  2 +-
 .../runtime/instructions/CPInstructionParser.java  |  7 ++++-
 .../runtime/instructions/cp/CPInstruction.java     |  2 +-
 .../instructions/cp/PrefetchCPInstruction.java     |  2 +-
 ...perationsTask.java => TriggerPrefetchTask.java} |  8 +++---
 .../cp/TriggerRemoteOperationsTask.java            | 33 +++++++++-------------
 ...ion.java => TriggerRemoteOpsCPInstruction.java} | 21 ++++++--------
 .../apache/sysds/utils/stats/SparkStatistics.java  | 14 +++++++--
 .../test/functions/async/PrefetchRDDTest.java      |  2 +-
 src/test/scripts/functions/async/PrefetchRDD3.dml  |  6 ++--
 10 files changed, 52 insertions(+), 45 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index a7cfa823aa..a9b8108600 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -241,7 +241,7 @@ public class Types
                CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, 
INVERSE,
                IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
                MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, 
STOP,
-               SVD, TAN, TANH, TYPEOF,
+               SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
                //fused ML-specific operators for performance 
                SPROP, //sample proportion: P * (1 - P)
                SIGMOID, //sigmoid function: 1 / (1 + exp(-X))
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index b53c954fd8..f2d3080ddc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -67,6 +67,7 @@ import 
org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.SqlCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.StringInitCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.TriggerRemoteOpsCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -328,6 +329,7 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "spoof",     CPType.SpoofFused);
                String2CPInstructionType.put( "prefetch",  CPType.Prefetch);
                String2CPInstructionType.put( "broadcast",  CPType.Broadcast);
+               String2CPInstructionType.put( "trigremote",  CPType.TrigRemote);
                String2CPInstructionType.put( Local.OPCODE, CPType.Local);
                
                String2CPInstructionType.put( "sql", CPType.Sql);
@@ -477,7 +479,10 @@ public class CPInstructionParser extends InstructionParser
                                
                        case Broadcast:
                                return 
BroadcastCPInstruction.parseInstruction(str);
-                       
+
+                       case TrigRemote:
+                               return 
TriggerRemoteOpsCPInstruction.parseInstruction(str);
+
                        default:
                                throw new DMLRuntimeException("Invalid CP 
Instruction Type: " + cptype );
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index acb1fd1dae..144760b3d9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -47,7 +47,7 @@ public abstract class CPInstruction extends Instruction
                MultiReturnParameterizedBuiltin, ParameterizedBuiltin, 
MultiReturnBuiltin,
                Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick, 
Local,
                MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, 
Compression, DeCompression, SpoofFused,
-               StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, 
Sql, Prefetch, Broadcast }
+               StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, 
Sql, Prefetch, Broadcast, TrigRemote }
 
        protected final CPType _cptype;
        protected final boolean _requiresLabelUpdate;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index 9d95a58dc3..db9bbb1b84 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -51,6 +51,6 @@ public class PrefetchCPInstruction extends UnaryCPInstruction 
{
                // In that case this Prefetch instruction will act like a NOOP. 
                if (CommonThreadPool.triggerRemoteOPsPool == null)
                        CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
-               CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerRemoteOperationsTask(ec.getMatrixObject(output)));
+               CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerPrefetchTask(ec.getMatrixObject(output)));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
similarity index 87%
copy from 
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
copy to 
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
index 255a2d61f6..26a1c8e8bb 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
@@ -24,10 +24,10 @@ import 
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.utils.stats.SparkStatistics;
 
-public class TriggerRemoteOperationsTask implements Runnable {
+public class TriggerPrefetchTask implements Runnable {
        MatrixObject _prefetchMO;
 
-       public TriggerRemoteOperationsTask(MatrixObject mo) {
+       public TriggerPrefetchTask(MatrixObject mo) {
                _prefetchMO = mo;
        }
 
@@ -35,8 +35,8 @@ public class TriggerRemoteOperationsTask implements Runnable {
        public void run() {
                boolean prefetched = false;
                synchronized (_prefetchMO) {
-                       // Having this check if operations are pending inside 
the 
-                       // critical section safeguards against concurrent rmVar.
+                       // Having this check inside the critical section
+                       // safeguards against concurrent rmVar.
                        if (_prefetchMO.isPendingRDDOps() || 
_prefetchMO.isFederated()) {
                                // TODO: Add robust runtime constraints for 
federated prefetch
                                // Execute and bring the result to local
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
index 255a2d61f6..63e6f56fd5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
@@ -19,37 +19,32 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
+import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.lops.Checkpoint;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.utils.stats.SparkStatistics;
 
 public class TriggerRemoteOperationsTask implements Runnable {
-       MatrixObject _prefetchMO;
+       MatrixObject _remoteOperationsRoot;
 
        public TriggerRemoteOperationsTask(MatrixObject mo) {
-               _prefetchMO = mo;
+               _remoteOperationsRoot = mo;
        }
 
        @Override
        public void run() {
-               boolean prefetched = false;
-               synchronized (_prefetchMO) {
-                       // Having this check if operations are pending inside 
the 
-                       // critical section safeguards against concurrent rmVar.
-                       if (_prefetchMO.isPendingRDDOps() || 
_prefetchMO.isFederated()) {
-                               // TODO: Add robust runtime constraints for 
federated prefetch
-                               // Execute and bring the result to local
-                               _prefetchMO.acquireReadAndRelease();
-                               prefetched = true;
+               boolean triggered = false;
+               synchronized (_remoteOperationsRoot) {
+                       if (_remoteOperationsRoot.isPendingRDDOps()) {
+                               JavaPairRDD<?, ?> rdd = 
_remoteOperationsRoot.getRDDHandle().getRDD();
+                               
rdd.persist(Checkpoint.DEFAULT_STORAGE_LEVEL).count();
+                               
_remoteOperationsRoot.getRDDHandle().setCheckpointRDD(true);
+                               triggered = true;
                        }
                }
-               if (DMLScript.STATISTICS && prefetched) {
-                       if (_prefetchMO.isFederated())
-                               FederatedStatistics.incAsyncPrefetchCount(1);
-                       else
-                               SparkStatistics.incAsyncPrefetchCount(1);
-               }
-       }
 
+               if (DMLScript.STATISTICS && triggered)
+                       SparkStatistics.incAsyncTriggerRemoteCount(1);
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
similarity index 71%
copy from 
src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
copy to 
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
index 9d95a58dc3..98d7f440c4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
 package org.apache.sysds.runtime.instructions.cp;
 
 import java.util.concurrent.Executors;
@@ -26,29 +25,27 @@ import 
org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 
-public class PrefetchCPInstruction extends UnaryCPInstruction {
-       private PrefetchCPInstruction(Operator op, CPOperand in, CPOperand out, 
String opcode, String istr) {
-               super(CPType.Prefetch, op, in, out, opcode, istr);
+public class TriggerRemoteOpsCPInstruction extends UnaryCPInstruction {
+       private TriggerRemoteOpsCPInstruction(Operator op, CPOperand in, 
CPOperand out, String opcode, String istr) {
+               super(CPType.TrigRemote, op, in, out, opcode, istr);
        }
-       
-       public static PrefetchCPInstruction parseInstruction (String str) {
+
+       public static TriggerRemoteOpsCPInstruction parseInstruction (String 
str) {
                InstructionUtils.checkNumFields(str, 2);
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
                CPOperand in = new CPOperand(parts[1]);
                CPOperand out = new CPOperand(parts[2]);
-               return new PrefetchCPInstruction(null, in, out, opcode, str);
+               return new TriggerRemoteOpsCPInstruction(null, in, out, opcode, 
str);
        }
 
        @Override
        public void processInstruction(ExecutionContext ec) {
-               //TODO: handle non-matrix objects
+               // TODO: Operator placement.
+               // Note for testing: write a method in the Dag class to place 
this operator
+               // after Spark MMRJ. Then execute 
PrefetchRDDTest.testAsyncSparkOPs3.
                ec.setVariable(output.getName(), ec.getMatrixObject(input1));
 
-               // Note, a Prefetch instruction doesn't guarantee an 
asynchronous execution.
-               // If the next instruction which takes this output as an input 
comes before
-               // the prefetch thread triggers, that instruction will start 
the operations.
-               // In that case this Prefetch instruction will act like a NOOP. 
                if (CommonThreadPool.triggerRemoteOPsPool == null)
                        CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
                CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerRemoteOperationsTask(ec.getMatrixObject(output)));
diff --git a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java 
b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
index 5263dbd119..3965feafdc 100644
--- a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
@@ -34,6 +34,7 @@ public class SparkStatistics {
        private static final LongAdder broadcastCount = new LongAdder();
        private static final LongAdder asyncPrefetchCount = new LongAdder();
        private static final LongAdder asyncBroadcastCount = new LongAdder();
+       private static final LongAdder asyncTriggerRemoteCount = new 
LongAdder();
 
        public static boolean createdSparkContext() {
                return ctxCreateTime > 0;
@@ -76,6 +77,10 @@ public class SparkStatistics {
                asyncBroadcastCount.add(c);
        }
 
+       public static void incAsyncTriggerRemoteCount(long c) {
+               asyncTriggerRemoteCount.add(c);
+       }
+
        public static long getSparkCollectCount() {
                return collectCount.longValue();
        }
@@ -88,6 +93,10 @@ public class SparkStatistics {
                return asyncBroadcastCount.longValue();
        }
 
+       public static long getAsyncTriggerRemoteCount() {
+               return asyncTriggerRemoteCount.longValue();
+       }
+
        public static void reset() {
                ctxCreateTime = 0;
                parallelizeTime.reset();
@@ -98,6 +107,7 @@ public class SparkStatistics {
                collectCount.reset();
                asyncPrefetchCount.reset();
                asyncBroadcastCount.reset();
+               asyncTriggerRemoteCount.reset();
        }
 
        public static String displayStatistics() {
@@ -114,8 +124,8 @@ public class SparkStatistics {
                                                broadcastTime.longValue()*1e-9,
                                                collectTime.longValue()*1e-9));
                if (OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS)
-                       sb.append("Spark async. count (pf,bc): \t" + 
-                                       String.format("%d/%d.\n", 
getAsyncPrefetchCount(), getAsyncBroadcastCount()));
+                       sb.append("Spark async. count (pf,bc,tr): \t" +
+                                       String.format("%d/%d/%d.\n", 
getAsyncPrefetchCount(), getAsyncBroadcastCount(), 
getAsyncTriggerRemoteCount()));
                return sb.toString();
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index 5a884724f6..61279bd036 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -64,7 +64,7 @@ public class PrefetchRDDTest extends AutomatedTestBase {
 
        @Test
        public void testAsyncSparkOPs3() {
-               //SP action type consumer. No Prefetch.
+               //SP binary consumer, followed by an action. No Prefetch.
                runTest(TEST_NAME+"3");
        }
        
diff --git a/src/test/scripts/functions/async/PrefetchRDD3.dml 
b/src/test/scripts/functions/async/PrefetchRDD3.dml
index 15286e3034..340115b46e 100644
--- a/src/test/scripts/functions/async/PrefetchRDD3.dml
+++ b/src/test/scripts/functions/async/PrefetchRDD3.dml
@@ -30,7 +30,7 @@ sp2 = sp1 %*% t(sp1);
 v = ((v + v) * 1 - v) / (1+1);
 v = ((v + v) * 2 - v) / (2+1);
 
-# CP binary triggers the DAG of SP operations
-cp = sp2 + sum(v);
-R = sum(cp);
+# SP sum triggers the DAG of SP operations
+SP = sp2 + sum(v); #spark transformation
+R = sum(SP); #action
 write(R, $1, format="text");

Reply via email to