This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new afa315d831 [SYSTEMDS-3443] Rename triggerremote instruction to 
checkpoint_e
afa315d831 is described below

commit afa315d8315b238ad29edf5e64c3d6954afd8cfd
Author: Arnab Phani <[email protected]>
AuthorDate: Wed Nov 30 18:59:13 2022 +0100

    [SYSTEMDS-3443] Rename triggerremote instruction to checkpoint_e
    
    This patch extends the checkpoint instruction with an asynchronous
    checkpoint_e (eager checkpoint) version. This operator eagerly triggers
    a chain of Spark operations and persists the distributed results.
    
    Closes #1744
---
 .../apache/sysds/conf/ConfigurationManager.java    |  4 +
 src/main/java/org/apache/sysds/conf/DMLConfig.java |  5 +-
 .../java/org/apache/sysds/hops/OptimizerUtils.java |  1 +
 .../java/org/apache/sysds/lops/Checkpoint.java     | 17 ++++-
 .../java/org/apache/sysds/lops/compile/Dag.java    | 86 ----------------------
 .../lops/compile/linearization/ILinearize.java     | 49 ++++++++++++
 .../runtime/instructions/CPInstructionParser.java  |  4 -
 .../runtime/instructions/SPInstructionParser.java  |  3 +-
 ...rationsTask.java => TriggerCheckpointTask.java} |  7 +-
 .../cp/TriggerRemoteOpsCPInstruction.java          | 53 -------------
 .../spark/CheckpointSPInstruction.java             | 19 ++++-
 .../apache/sysds/utils/stats/SparkStatistics.java  | 16 ++--
 12 files changed, 103 insertions(+), 161 deletions(-)

diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java 
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index bb6172993a..12764eacf4 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -247,6 +247,10 @@ public class ConfigurationManager
                return 
(getDMLConfig().getBooleanValue(DMLConfig.ASYNC_SPARK_BROADCAST)
                        || OptimizerUtils.ASYNC_BROADCAST_SPARK);
        }
+       public static boolean isCheckpointEnabled() {
+               return 
(getDMLConfig().getBooleanValue(DMLConfig.ASYNC_SPARK_CHECKPOINT)
+                       || OptimizerUtils.ASYNC_CHECKPOINT_SPARK);
+       }
 
        public static ILinearize.DagLinearization getLinearizationOrder() {
                if (OptimizerUtils.MAX_PARALLELIZE_ORDER)
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java 
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index 2225d89d89..dad670efd4 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -130,6 +130,7 @@ public class DMLConfig
        /** Asynchronous triggering of Spark OPs and operator placement **/
        public static final String ASYNC_SPARK_PREFETCH = 
"sysds.async.prefetch";  // boolean: enable asynchronous prefetching spark 
intermediates
        public static final String ASYNC_SPARK_BROADCAST = 
"sysds.async.broadcast";  // boolean: enable asynchronous broadcasting CP 
intermediates
+       public static final String ASYNC_SPARK_CHECKPOINT = 
"sysds.async.checkpoint";  // boolean: enable asynchronous persisting of Spark 
intermediates
        //internal config
        public static final String DEFAULT_SHARED_DIR_PERMISSION = "777"; //for 
local fs and DFS
        
@@ -202,6 +203,7 @@ public class DMLConfig
                _defaultVals.put(PRIVACY_CONSTRAINT_MOCK, null);
                _defaultVals.put(ASYNC_SPARK_PREFETCH,   "false" );
                _defaultVals.put(ASYNC_SPARK_BROADCAST,  "false" );
+               _defaultVals.put(ASYNC_SPARK_CHECKPOINT,  "false" );
        }
        
        public DMLConfig() {
@@ -454,7 +456,8 @@ public class DMLConfig
                        PRINT_GPU_MEMORY_INFO, AVAILABLE_GPUS, SYNCHRONIZE_GPU, 
EAGER_CUDA_FREE, FLOATING_POINT_PRECISION,
                        GPU_EVICTION_POLICY, LOCAL_SPARK_NUM_THREADS, 
EVICTION_SHADOW_BUFFERSIZE, GPU_MEMORY_ALLOCATOR,
                        GPU_MEMORY_UTILIZATION_FACTOR, 
USE_SSL_FEDERATED_COMMUNICATION, DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT,
-                       FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY, 
ASYNC_SPARK_PREFETCH, ASYNC_SPARK_BROADCAST
+                       FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY, 
ASYNC_SPARK_PREFETCH, ASYNC_SPARK_BROADCAST,
+                       ASYNC_SPARK_CHECKPOINT
                }; 
                
                StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index d2e9670362..580ddccf20 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -284,6 +284,7 @@ public class OptimizerUtils
         */
        public static boolean ASYNC_PREFETCH_SPARK = false;
        public static boolean ASYNC_BROADCAST_SPARK = false;
+       public static boolean ASYNC_CHECKPOINT_SPARK = false;
 
        /**
         * Heuristic-based instruction ordering to maximize inter-operator 
parallelism.
diff --git a/src/main/java/org/apache/sysds/lops/Checkpoint.java 
b/src/main/java/org/apache/sysds/lops/Checkpoint.java
index a15078bc70..6f56d287c0 100644
--- a/src/main/java/org/apache/sysds/lops/Checkpoint.java
+++ b/src/main/java/org/apache/sysds/lops/Checkpoint.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 
+import java.util.Arrays;
 
 /**
  * Lop for checkpoint operations. For example, on Spark, the semantic of a 
checkpoint 
@@ -38,13 +39,15 @@ import 
org.apache.sysds.runtime.instructions.InstructionUtils;
  */
 public class Checkpoint extends Lop 
 {
-       public static final String OPCODE = "chkpoint"; 
-        
+       public static final String DEFAULT_CP_OPCODE = "chkpoint";
+       public static final String ASYNC_CP_OPCODE = "chkpoint_e";
+
        public static final StorageLevel DEFAULT_STORAGE_LEVEL = 
StorageLevel.MEMORY_AND_DISK();
        public static final StorageLevel SER_STORAGE_LEVEL = 
StorageLevel.MEMORY_AND_DISK_SER();
        public static final boolean CHECKPOINT_SPARSE_CSR = true; 
 
        private StorageLevel _storageLevel;
+       private boolean _async = false;
        
 
        /**
@@ -55,16 +58,22 @@ public class Checkpoint extends Lop
         * @param dt data type
         * @param vt value type
         * @param level storage level
+        * @param isAsync true if eager and asynchronous checkpoint
         */
-       public Checkpoint(Lop input, DataType dt, ValueType vt, String level)  {
+       public Checkpoint(Lop input, DataType dt, ValueType vt, String level, 
boolean isAsync)  {
                super(Lop.Type.Checkpoint, dt, vt);
                addInput(input);
                input.addOutput(this);
                
                _storageLevel = StorageLevel.fromString(level);
+               _async = isAsync;
                lps.setProperties(inputs, ExecType.SPARK);
        }
 
+       public Checkpoint(Lop input, DataType dt, ValueType vt, String level)  {
+               this(input, dt, vt, level, false);
+       }
+
        public StorageLevel getStorageLevel()
        {
                return _storageLevel;
@@ -89,7 +98,7 @@ public class Checkpoint extends Lop
                
                return InstructionUtils.concatOperands(
                        getExecType().name(),
-                       OPCODE,
+                       _async ? ASYNC_CP_OPCODE : DEFAULT_CP_OPCODE,
                        getInputs().get(0).prepInputOperand(input1),
                        prepOutputOperand(output),
                        getStorageLevelString(_storageLevel));
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java 
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 2efbea8221..77325fb297 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -237,56 +237,6 @@ public class Dag<N extends Lop>
                }
        }
        
-       private static List<Lop> addPrefetchLop(List<Lop> nodes) {
-               List<Lop> nodesWithPrefetch = new ArrayList<>();
-               
-               //Find the Spark nodes with all CP outputs
-               for (Lop l : nodes) {
-                       nodesWithPrefetch.add(l);
-                       if (isPrefetchNeeded(l)) {
-                               //TODO: No prefetch if the parent is placed 
right after the spark OP
-                               //or push the parent further to increase 
parallelism
-                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
-                               //Construct a Prefetch lop that takes this 
Spark node as a input
-                               UnaryCP prefetch = new UnaryCP(l, 
OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
-                               for (Lop outCP : oldOuts) {
-                                       //Rewire l -> outCP to l -> Prefetch -> 
outCP
-                                       prefetch.addOutput(outCP);
-                                       outCP.replaceInput(l, prefetch);
-                                       l.removeOutput(outCP);
-                                       //FIXME: Rewire _inputParams when 
needed (e.g. GroupedAggregate)
-                               }
-                               //Place it immediately after the Spark lop in 
the node list
-                               nodesWithPrefetch.add(prefetch);
-                       }
-               }
-               return nodesWithPrefetch;
-       }
-
-       private static List<Lop> addBroadcastLop(List<Lop> nodes) {
-               List<Lop> nodesWithBroadcast = new ArrayList<>();
-               
-               for (Lop l : nodes) {
-                       nodesWithBroadcast.add(l);
-                       if (isBroadcastNeeded(l)) {
-                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
-                               //Construct a Broadcast lop that takes this 
Spark node as an input
-                               UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST, 
l.getDataType(), l.getValueType(), ExecType.CP);
-                               //FIXME: Wire Broadcast only with the necessary 
outputs
-                               for (Lop outCP : oldOuts) {
-                                       //Rewire l -> outCP to l -> Broadcast 
-> outCP
-                                       bc.addOutput(outCP);
-                                       outCP.replaceInput(l, bc);
-                                       l.removeOutput(outCP);
-                                       //FIXME: Rewire _inputParams when 
needed (e.g. GroupedAggregate)
-                               }
-                               //Place it immediately after the Spark lop in 
the node list
-                               nodesWithBroadcast.add(bc);
-                       }
-               }
-               return nodesWithBroadcast;
-       }
-       
        private ArrayList<Instruction> doPlainInstructionGen(StatementBlock sb, 
List<Lop> nodes)
        {
                //prepare basic instruction sets
@@ -319,42 +269,6 @@ public class Dag<N extends Lop>
                        && 
dnode.getOutputParameters().getLabel().equals(input.getOutputParameters().getLabel());
        }
        
-       private static boolean isPrefetchNeeded(Lop lop) {
-               // Run Prefetch for a Spark instruction if the instruction is a 
Transformation
-               // and the output is consumed by only CP instructions.
-               boolean transformOP = lop.getExecType() == ExecType.SPARK && 
lop.getAggType() != SparkAggType.SINGLE_BLOCK
-                               // Always Action operations
-                               && !(lop.getDataType() == DataType.SCALAR)
-                               && !(lop instanceof MapMultChain) && !(lop 
instanceof PickByCount)
-                               && !(lop instanceof MMZip) && !(lop instanceof 
CentralMoment)
-                               && !(lop instanceof CoVariance) 
-                               // Not qualified for prefetching
-                               && !(lop instanceof Checkpoint) && !(lop 
instanceof ReBlock)
-                               && !(lop instanceof CSVReBlock)
-                               // Cannot filter Transformation cases from 
Actions (FIXME)
-                               && !(lop instanceof MMTSJ) && !(lop instanceof 
UAggOuterChain)
-                               && !(lop instanceof ParameterizedBuiltin) && 
!(lop instanceof SpoofFused);
-
-               //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
-               boolean hasParameterizedOut = lop.getOutputs().stream()
-                               .anyMatch(out -> ((out instanceof 
ParameterizedBuiltin) 
-                                       || (out instanceof GroupedAggregate)
-                                       || (out instanceof GroupedAggregateM)));
-               //TODO: support non-matrix outputs
-               return transformOP && !hasParameterizedOut 
-                               && lop.isAllOutputsCP() && lop.getDataType() == 
DataType.MATRIX;
-       }
-       
-       private static boolean isBroadcastNeeded(Lop lop) {
-               // Asynchronously broadcast a matrix if that is produced by a 
CP instruction,
-               // and at least one Spark parent needs to broadcast this 
intermediate (eg. mapmm)
-               boolean isBc = lop.getOutputs().stream()
-                               .anyMatch(out -> (out.getBroadcastInput() == 
lop));
-               //TODO: Early broadcast objects that are bigger than a single 
block
-               //return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
-               return isBc && lop.getDataType() == DataType.MATRIX;
-       }
-       
        private static List<Instruction> 
deleteUpdatedTransientReadVariables(StatementBlock sb, List<Lop> nodeV) {
                List<Instruction> insts = new ArrayList<>();
                if ( sb == null ) //return modifiable list
diff --git 
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java 
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index d867a91f4a..e78530b33b 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -319,6 +319,31 @@ public interface ILinearize {
                return nodesWithBroadcast;
        }
 
+       private static List<Lop> addAsyncEagerCheckpointLop(List<Lop> nodes) {
+               List<Lop> nodesWithCheckpoint = new ArrayList<>();
+                // Find the Spark action nodes
+               for (Lop l : nodes) {
+                       if (isCheckpointNeeded(l)) {
+                               List<Lop> oldInputs = new 
ArrayList<>(l.getInputs());
+                               // Place a Checkpoint node just below this node 
(Spark action)
+                               for (Lop in : oldInputs) {
+                                       if (in.getExecType() != ExecType.SPARK)
+                                               continue;
+                                       // Rewire in -> l to in -> Checkpoint 
-> l
+                                       //UnaryCP checkpoint = new UnaryCP(in, 
OpOp1.TRIGREMOTE, in.getDataType(), in.getValueType(), ExecType.CP);
+                                       Lop checkpoint = new Checkpoint(in, 
in.getDataType(), in.getValueType(),
+                                               
Checkpoint.getDefaultStorageLevelString(), true);
+                                       checkpoint.addOutput(l);
+                                       l.replaceInput(in, checkpoint);
+                                       in.removeOutput(l);
+                                       nodesWithCheckpoint.add(checkpoint);
+                               }
+                       }
+                       nodesWithCheckpoint.add(l);
+               }
+               return nodesWithCheckpoint;
+       }
+
        private static boolean isPrefetchNeeded(Lop lop) {
                // Run Prefetch for a Spark instruction if the instruction is a 
Transformation
                // and the output is consumed by only CP instructions.
@@ -354,4 +379,28 @@ public interface ILinearize {
                //return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
                return isBc && lop.getDataType() == DataType.MATRIX;
        }
+
+       private static boolean isCheckpointNeeded(Lop lop) {
+               // Place checkpoint_e just before a Spark action (FIXME)
+               boolean actionOP = lop.getExecType() == ExecType.SPARK
+                               && ((lop.getAggType() == 
SparkAggType.SINGLE_BLOCK)
+                               // Always Action operations
+                               || (lop.getDataType() == DataType.SCALAR)
+                               || (lop instanceof MapMultChain) || (lop 
instanceof PickByCount)
+                               || (lop instanceof MMZip) || (lop instanceof 
CentralMoment)
+                               || (lop instanceof CoVariance) || (lop 
instanceof MMTSJ))
+                               // Not qualified for Checkpoint
+                               && !(lop instanceof Checkpoint) && !(lop 
instanceof ReBlock)
+                               && !(lop instanceof CSVReBlock)
+                               // Cannot filter Transformation cases from 
Actions (FIXME)
+                               && !(lop instanceof UAggOuterChain)
+                               && !(lop instanceof ParameterizedBuiltin) && 
!(lop instanceof SpoofFused);
+
+               //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
+               boolean hasParameterizedOut = lop.getOutputs().stream()
+                               .anyMatch(out -> ((out instanceof 
ParameterizedBuiltin)
+                                       || (out instanceof GroupedAggregate)
+                                       || (out instanceof GroupedAggregateM)));
+               return actionOP && !hasParameterizedOut;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index b83a78674b..78dec00c24 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -67,7 +67,6 @@ import 
org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.SqlCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.StringInitCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.TriggerRemoteOpsCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
@@ -482,9 +481,6 @@ public class CPInstructionParser extends InstructionParser
                        case Broadcast:
                                return 
BroadcastCPInstruction.parseInstruction(str);
 
-                       case TrigRemote:
-                               return 
TriggerRemoteOpsCPInstruction.parseInstruction(str);
-
                        default:
                                throw new DMLRuntimeException("Invalid CP 
Instruction Type: " + cptype );
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 0ff745de57..773153d6d4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -240,7 +240,8 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put("libsvmrblk", 
SPType.LIBSVMReblock);
 
                // Spark-specific instructions
-               String2SPInstructionType.put( Checkpoint.OPCODE, 
SPType.Checkpoint);
+               String2SPInstructionType.put( Checkpoint.DEFAULT_CP_OPCODE, 
SPType.Checkpoint);
+               String2SPInstructionType.put( Checkpoint.ASYNC_CP_OPCODE, 
SPType.Checkpoint);
                String2SPInstructionType.put( Compression.OPCODE, 
SPType.Compression);
                String2SPInstructionType.put( DeCompression.OPCODE, 
SPType.DeCompression);
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerCheckpointTask.java
similarity index 89%
rename from 
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
rename to 
src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerCheckpointTask.java
index 63e6f56fd5..657460f62e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOperationsTask.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerCheckpointTask.java
@@ -25,10 +25,10 @@ import org.apache.sysds.lops.Checkpoint;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.utils.stats.SparkStatistics;
 
-public class TriggerRemoteOperationsTask implements Runnable {
+public class TriggerCheckpointTask implements Runnable {
        MatrixObject _remoteOperationsRoot;
 
-       public TriggerRemoteOperationsTask(MatrixObject mo) {
+       public TriggerCheckpointTask(MatrixObject mo) {
                _remoteOperationsRoot = mo;
        }
 
@@ -36,6 +36,7 @@ public class TriggerRemoteOperationsTask implements Runnable {
        public void run() {
                boolean triggered = false;
                synchronized (_remoteOperationsRoot) {
+                       // FIXME: Handle double execution
                        if (_remoteOperationsRoot.isPendingRDDOps()) {
                                JavaPairRDD<?, ?> rdd = 
_remoteOperationsRoot.getRDDHandle().getRDD();
                                
rdd.persist(Checkpoint.DEFAULT_STORAGE_LEVEL).count();
@@ -45,6 +46,6 @@ public class TriggerRemoteOperationsTask implements Runnable {
                }
 
                if (DMLScript.STATISTICS && triggered)
-                       SparkStatistics.incAsyncTriggerRemoteCount(1);
+                       SparkStatistics.incAsyncTriggerCheckpointCount(1);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
deleted file mode 100644
index 98d7f440c4..0000000000
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerRemoteOpsCPInstruction.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.sysds.runtime.instructions.cp;
-
-import java.util.concurrent.Executors;
-
-import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.matrix.operators.Operator;
-import org.apache.sysds.runtime.util.CommonThreadPool;
-
-public class TriggerRemoteOpsCPInstruction extends UnaryCPInstruction {
-       private TriggerRemoteOpsCPInstruction(Operator op, CPOperand in, 
CPOperand out, String opcode, String istr) {
-               super(CPType.TrigRemote, op, in, out, opcode, istr);
-       }
-
-       public static TriggerRemoteOpsCPInstruction parseInstruction (String 
str) {
-               InstructionUtils.checkNumFields(str, 2);
-               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
-               String opcode = parts[0];
-               CPOperand in = new CPOperand(parts[1]);
-               CPOperand out = new CPOperand(parts[2]);
-               return new TriggerRemoteOpsCPInstruction(null, in, out, opcode, 
str);
-       }
-
-       @Override
-       public void processInstruction(ExecutionContext ec) {
-               // TODO: Operator placement.
-               // Note for testing: write a method in the Dag class to place 
this operator
-               // after Spark MMRJ. Then execute 
PrefetchRDDTest.testAsyncSparkOPs3.
-               ec.setVariable(output.getName(), ec.getMatrixObject(input1));
-
-               if (CommonThreadPool.triggerRemoteOPsPool == null)
-                       CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
-               CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerRemoteOperationsTask(ec.getMatrixObject(output)));
-       }
-}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
index f305630891..56455474b7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
@@ -35,6 +35,7 @@ import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.BooleanObject;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.TriggerCheckpointTask;
 import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import 
org.apache.sysds.runtime.instructions.spark.functions.CopyFrameBlockFunction;
 import 
org.apache.sysds.runtime.instructions.spark.functions.CreateSparseBlockFunction;
@@ -43,9 +44,12 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import org.apache.sysds.utils.Statistics;
 
+import java.util.concurrent.Executors;
+
 public class CheckpointSPInstruction extends UnarySPInstruction {
        // default storage level
        private StorageLevel _level = null;
@@ -71,7 +75,20 @@ public class CheckpointSPInstruction extends 
UnarySPInstruction {
        @SuppressWarnings("unchecked")
        public void processInstruction(ExecutionContext ec) {
                SparkExecutionContext sec = (SparkExecutionContext)ec;
-               
+
+               // Asynchronously trigger count() and persist this RDD
+               // TODO: Synchronize. Avoid double execution
+               if (getOpcode().equals("chkpoint_e")) {  //eager checkpoint
+                       // Inplace replace output matrix object with the input 
matrix object
+                       // We will never use the output of the Spark count call
+                       ec.setVariable(output.getName(), 
ec.getCacheableData(input1));
+
+                       if (CommonThreadPool.triggerRemoteOPsPool == null)
+                               CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
+                       CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerCheckpointTask(ec.getMatrixObject(output)));
+                       return;
+               }
+
                // Step 1: early abort on non-existing or in-memory (cached) 
inputs
                // -------
                // (checkpoints are generated for all read only variables in 
loops; due to unbounded scoping and 
diff --git a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java 
b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
index 331634d1bd..0d1b5b05b9 100644
--- a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
@@ -33,7 +33,7 @@ public class SparkStatistics {
        private static final LongAdder broadcastCount = new LongAdder();
        private static final LongAdder asyncPrefetchCount = new LongAdder();
        private static final LongAdder asyncBroadcastCount = new LongAdder();
-       private static final LongAdder asyncTriggerRemoteCount = new 
LongAdder();
+       private static final LongAdder asyncTriggerCheckpointCount = new 
LongAdder();
 
        public static boolean createdSparkContext() {
                return ctxCreateTime > 0;
@@ -76,8 +76,8 @@ public class SparkStatistics {
                asyncBroadcastCount.add(c);
        }
 
-       public static void incAsyncTriggerRemoteCount(long c) {
-               asyncTriggerRemoteCount.add(c);
+       public static void incAsyncTriggerCheckpointCount(long c) {
+               asyncTriggerCheckpointCount.add(c);
        }
 
        public static long getSparkCollectCount() {
@@ -92,8 +92,8 @@ public class SparkStatistics {
                return asyncBroadcastCount.longValue();
        }
 
-       public static long getAsyncTriggerRemoteCount() {
-               return asyncTriggerRemoteCount.longValue();
+       public static long getasyncTriggerCheckpointCount() {
+               return asyncTriggerCheckpointCount.longValue();
        }
 
        public static void reset() {
@@ -106,7 +106,7 @@ public class SparkStatistics {
                collectCount.reset();
                asyncPrefetchCount.reset();
                asyncBroadcastCount.reset();
-               asyncTriggerRemoteCount.reset();
+               asyncTriggerCheckpointCount.reset();
        }
 
        public static String displayStatistics() {
@@ -122,8 +122,8 @@ public class SparkStatistics {
                                                
parallelizeTime.longValue()*1e-9,
                                                broadcastTime.longValue()*1e-9,
                                                collectTime.longValue()*1e-9));
-               sb.append("Spark async. count (pf,bc,tr): \t" +
-                               String.format("%d/%d/%d.\n", 
getAsyncPrefetchCount(), getAsyncBroadcastCount(), 
getAsyncTriggerRemoteCount()));
+               sb.append("Spark async. count (pf,bc,cp): \t" +
+                               String.format("%d/%d/%d.\n", 
getAsyncPrefetchCount(), getAsyncBroadcastCount(), 
getasyncTriggerCheckpointCount()));
                return sb.toString();
        }
 }

Reply via email to