This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 53c8157746 [SYSTEMDS-3519] Extend Prefetch instruction for GPU to CP
53c8157746 is described below

commit 53c81577465ee6d42b10f5a29d03cd504f978d56
Author: Arnab Phani <[email protected]>
AuthorDate: Mon Sep 25 13:20:27 2023 +0200

    [SYSTEMDS-3519] Extend Prefetch instruction for GPU to CP
    
    This patch enables the prefetch instruction to asynchronously copy
    intermediates from GPU to local (host) memory. Because prefetch results
    are reused, this change also allows removing synchronization barriers
    between GPU and CPU by reusing the prefetched matrix blocks.
---
 .../java/org/apache/sysds/conf/ConfigurationManager.java |  4 ++--
 src/main/java/org/apache/sysds/conf/DMLConfig.java       |  6 +++---
 src/main/java/org/apache/sysds/hops/OptimizerUtils.java  | 10 +++++-----
 .../sysds/lops/compile/linearization/ILinearize.java     |  1 +
 .../java/org/apache/sysds/lops/rewrite/LopRewriter.java  |  2 +-
 .../apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java | 10 ++++++++++
 .../runtime/controlprogram/caching/CacheableData.java    |  7 ++++++-
 .../runtime/instructions/cp/TriggerPrefetchTask.java     |  3 ++-
 .../org/apache/sysds/runtime/lineage/LineageCache.java   |  8 ++++++--
 .../sysds/runtime/lineage/LineageCacheStatistics.java    | 16 +++++++---------
 src/main/java/org/apache/sysds/utils/Statistics.java     |  2 +-
 .../org/apache/sysds/utils/stats/SparkStatistics.java    |  2 +-
 .../test/functions/async/MaxParallelizeOrderTest.java    |  4 ++--
 .../sysds/test/functions/async/PrefetchRDDTest.java      |  4 ++--
 .../sysds/test/functions/async/ReuseAsyncOpTest.java     |  4 ++--
 .../sysds/test/functions/lineage/GPUFullReuseTest.java   | 15 +++++++++------
 16 files changed, 60 insertions(+), 38 deletions(-)

diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java 
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index a4b5c0ffec..1ac4d13974 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -262,8 +262,8 @@ public class ConfigurationManager{
        }
 
        public static boolean isPrefetchEnabled() {
-               return 
(getDMLConfig().getBooleanValue(DMLConfig.ASYNC_SPARK_PREFETCH)
-                       || OptimizerUtils.ASYNC_PREFETCH_SPARK);
+               return (getDMLConfig().getBooleanValue(DMLConfig.ASYNC_PREFETCH)
+                       || OptimizerUtils.ASYNC_PREFETCH);
        }
 
        public static boolean isMaxPrallelizeEnabled() {
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java 
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index 103cf0a01e..767026a161 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -131,7 +131,7 @@ public class DMLConfig
        public static final int DEFAULT_FEDERATED_PORT = 4040; // borrowed 
default Spark Port
        public static final int DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS = 8;
        /** Asynchronous triggering of Spark OPs and operator placement **/
-       public static final String ASYNC_SPARK_PREFETCH = 
"sysds.async.prefetch";  // boolean: enable asynchronous prefetching spark 
intermediates
+       public static final String ASYNC_PREFETCH = "sysds.async.prefetch";  // 
boolean: enable asynchronous prefetching spark/gpu intermediates
        public static final String ASYNC_SPARK_BROADCAST = 
"sysds.async.broadcast";  // boolean: enable asynchronous broadcasting CP 
intermediates
        public static final String ASYNC_SPARK_CHECKPOINT = 
"sysds.async.checkpoint";  // boolean: enable compile-time persisting of Spark 
intermediates
        //internal config
@@ -207,7 +207,7 @@ public class DMLConfig
                _defaultVals.put(FEDERATED_MONITOR_FREQUENCY, "3");
                _defaultVals.put(FEDERATED_COMPRESSION, "none");
                _defaultVals.put(PRIVACY_CONSTRAINT_MOCK, null);
-               _defaultVals.put(ASYNC_SPARK_PREFETCH,   "false" );
+               _defaultVals.put(ASYNC_PREFETCH,   "false" );
                _defaultVals.put(ASYNC_SPARK_BROADCAST,  "false" );
                _defaultVals.put(ASYNC_SPARK_CHECKPOINT,  "false" );
        }
@@ -463,7 +463,7 @@ public class DMLConfig
                        FLOATING_POINT_PRECISION, GPU_EVICTION_POLICY, 
LOCAL_SPARK_NUM_THREADS, EVICTION_SHADOW_BUFFERSIZE,
                        GPU_MEMORY_ALLOCATOR, GPU_MEMORY_UTILIZATION_FACTOR, 
USE_SSL_FEDERATED_COMMUNICATION,
                        DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT, 
FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY, FEDERATED_COMPRESSION,
-                       ASYNC_SPARK_PREFETCH, ASYNC_SPARK_BROADCAST, 
ASYNC_SPARK_CHECKPOINT
+                       ASYNC_PREFETCH, ASYNC_SPARK_BROADCAST, 
ASYNC_SPARK_CHECKPOINT
                }; 
                
                StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 1a7cc1c02b..8da0ff110d 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -285,16 +285,16 @@ public class OptimizerUtils
        public static boolean ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
 
        /**
-        * Enable prefetch and broadcast. Prefetch asynchronously calls 
acquireReadAndRelease() to trigger a chain of spark
-        * transformations, which would would otherwise make the next 
instruction wait till completion. Broadcast allows
+        * Enable prefetch and broadcast. Prefetch asynchronously calls 
acquireReadAndRelease() to trigger remote
+        * operations, which would otherwise make the next instruction wait 
till completion. Broadcast allows
         * asynchronously transferring the data to all the nodes.
         */
-       public static boolean ASYNC_PREFETCH_SPARK = false;
+       public static boolean ASYNC_PREFETCH = false; //both Spark and GPU
        public static boolean ASYNC_BROADCAST_SPARK = false;
        public static boolean ASYNC_CHECKPOINT_SPARK = false;
 
        /**
-        * Heuristic-based instruction ordering to maximize inter-operator 
parallelism.
+        * Heuristic-based instruction ordering to maximize inter-operator 
PARALLELISM.
         * Place the Spark operator chains first and trigger them to execute in 
parallel.
         */
        public static boolean MAX_PARALLELIZE_ORDER = false;
@@ -308,7 +308,7 @@ public class OptimizerUtils
        /**
         * Rule-based operator placement policy for GPU.
         */
-       public static boolean RULE_BASED_GPU_EXEC = false;
+       public static boolean RULE_BASED_GPU_EXEC = true;
 
        //////////////////////
        // Optimizer levels //
diff --git 
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java 
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index 3c0aa61692..364dd662b8 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -210,6 +210,7 @@ public class ILinearize {
                        final_v = depthFirst(v);
 
                return final_v;
+               //TODO: Support GPU operator chains
        }
 
        // Place the operators in a depth-first manner, but order
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java 
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
index 2b054d9b2b..8d2c0a63f8 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -40,11 +40,11 @@ public class LopRewriter
        public LopRewriter() {
                _lopSBRuleSet = new ArrayList<>();
                // Add rewrite rules (single and multi-statement block)
+               _lopSBRuleSet.add(new RewriteUpdateGPUPlacements());
                _lopSBRuleSet.add(new RewriteAddPrefetchLop());
                _lopSBRuleSet.add(new RewriteAddBroadcastLop());
                _lopSBRuleSet.add(new RewriteAddChkpointLop());
                _lopSBRuleSet.add(new RewriteAddChkpointInLoop());
-               _lopSBRuleSet.add(new RewriteUpdateGPUPlacements());
                // TODO: A rewrite pass to remove less effective chkpoints
                // Last rewrite to reset Lop IDs in a depth-first manner
                _lopSBRuleSet.add(new RewriteFixIDs());
diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
index 91b7f81e71..4567e88d52 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
@@ -90,6 +90,10 @@ public class RewriteAddPrefetchLop extends LopRewriteRule
        }
 
        private boolean isPrefetchNeeded(Lop lop) {
+               return isPrefetchFromSparkNeeded(lop) || 
isPrefetchFromGPUNeeded(lop);
+       }
+
+       private boolean isPrefetchFromSparkNeeded(Lop lop) {
                // Run Prefetch for a Spark instruction if the instruction is a 
Transformation
                // and the output is consumed by only CP instructions.
                boolean transformOP = lop.getExecType() == Types.ExecType.SPARK 
&& lop.getAggType() != AggBinaryOp.SparkAggType.SINGLE_BLOCK
@@ -119,4 +123,10 @@ public class RewriteAddPrefetchLop extends LopRewriteRule
                        && (lop.isAllOutputsCP() || 
OperatorOrderingUtils.isCollectForBroadcast(lop))
                        && lop.getDataType() == Types.DataType.MATRIX;
        }
+
+       private boolean isPrefetchFromGPUNeeded(Lop lop) {
+               // Prefetch a GPU intermediate if all the outputs are CP.
+               return lop.getDataType() == Types.DataType.MATRIX
+                       && lop.isExecGPU() && lop.isAllOutputsCP();
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index a10b13284f..6961173af2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -1280,7 +1280,12 @@ public abstract class CacheableData<T extends 
CacheBlock<?>> extends Data
        public boolean isPendingRDDOps() {
                return isEmpty(true) && _data == null && (_rddHandle != null && 
_rddHandle.hasBackReference());
        }
-       
+
+       public boolean isDeviceToHostCopy() {
+               boolean isGpuOP = isEmpty(true) && _data == null && _gpuObjects 
!= null;
+               return isGpuOP && _gpuObjects.values().stream().anyMatch(gobj 
-> (gobj != null && gobj.isDirty()));
+       }
+
        protected void setEmpty() {
                _cacheStatus = CacheStatus.EMPTY;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
index 78857c5a17..f1d5a8d3f6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
@@ -49,7 +49,8 @@ public class TriggerPrefetchTask implements Runnable {
                synchronized (_prefetchMO) {
                        // Having this check inside the critical section
                        // safeguards against concurrent rmVar.
-                       if (_prefetchMO.isPendingRDDOps() || 
_prefetchMO.isFederated()) {
+                       if (_prefetchMO.isPendingRDDOps() || 
_prefetchMO.isDeviceToHostCopy()
+                               || _prefetchMO.isFederated()) {
                                // TODO: Add robust runtime constraints for 
federated prefetch
                                // Execute and bring the result to local
                                mb = _prefetchMO.acquireReadAndRelease();
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index d05ea234a1..7aacfb5151 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -641,8 +641,9 @@ public class LineageCache
 
                                // Scalar gpu intermediates is always copied 
back to host. 
                                // No need to cache the GPUobj for scalar 
intermediates.
+                               instLI = ec.getLineageItem(((GPUInstruction) 
inst)._output);
                                if (liGPUObj == null)
-                                       liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((GPUInstruction)inst)._output)));
+                                       liData = Arrays.asList(Pair.of(instLI, 
ec.getVariable(((GPUInstruction) inst)._output)));
                        }
                        else if (inst instanceof ComputationSPInstruction
                                && (ec.getVariable(((ComputationSPInstruction) 
inst).output) instanceof MatrixObject)
@@ -1463,6 +1464,8 @@ public class LineageCache
 
                
LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
                if (e.isGPUObject()) LineageCacheStatistics.incrementGpuHits();
+               if (inst.getOpcode().equals("prefetch") && 
DMLScript.USE_ACCELERATOR)
+                       LineageCacheStatistics.incrementGpuPrefetch();
                if (e.isRDDPersist()) {
                        if 
(SparkExecutionContext.isRDDCached(e.getRDDObject().getRDD().id()))
                                
LineageCacheStatistics.incrementRDDPersistHits(); //persisted in the executors
@@ -1470,7 +1473,8 @@ public class LineageCache
                                LineageCacheStatistics.incrementRDDHits();  
//only locally cached
                }
                if (e.isMatrixValue() || e.isScalarValue()) {
-                       if (inst instanceof ComputationSPInstruction || 
inst.getOpcode().equals("prefetch"))
+                       if (inst instanceof ComputationSPInstruction
+                               || (inst.getOpcode().equals("prefetch") && 
!DMLScript.USE_ACCELERATOR))
                                // Single_block Spark instructions (sync/async) 
and prefetch
                                
LineageCacheStatistics.incrementSparkCollectHits();
                        else
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index 00fd36c378..be7460f6fd 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -44,7 +44,7 @@ public class LineageCacheStatistics {
        private static final LongAdder _ctimeProbe      = new LongAdder();
        // Bellow entries are specific to gpu lineage cache
        private static final LongAdder _numHitsGpu      = new LongAdder();
-       private static final LongAdder _numAsyncEvictGpu= new LongAdder();
+       private static final LongAdder _numPrefetchGpu= new LongAdder();
        private static final LongAdder _numSyncEvictGpu = new LongAdder();
        private static final LongAdder _numRecycleGpu   = new LongAdder();
        private static final LongAdder _numDelGpu       = new LongAdder();
@@ -74,7 +74,7 @@ public class LineageCacheStatistics {
                _ctimeProbe.reset();
                _evtimeGpu.reset();
                _numHitsGpu.reset();
-               _numAsyncEvictGpu.reset();
+               _numPrefetchGpu.reset();
                _numSyncEvictGpu.reset();
                _numRecycleGpu.reset();
                _numDelGpu.reset();
@@ -210,9 +210,9 @@ public class LineageCacheStatistics {
                _numHitsGpu.increment();
        }
 
-       public static void incrementGpuAsyncEvicts() {
-               // Number of gpu cache entries moved to cpu cache via the 
background thread
-               _numAsyncEvictGpu.increment();
+       public static void incrementGpuPrefetch() {
+               // Number of reuse of GPU to host prefetches (asynchronous)
+               _numPrefetchGpu.increment();
        }
 
        public static void incrementGpuSyncEvicts() {
@@ -318,9 +318,7 @@ public class LineageCacheStatistics {
                StringBuilder sb = new StringBuilder();
                sb.append(_numHitsGpu.longValue());
                sb.append("/");
-               sb.append(_numAsyncEvictGpu.longValue());
-               sb.append("/");
-               sb.append(_numSyncEvictGpu.longValue());
+               sb.append(_numPrefetchGpu.longValue());
                return sb.toString();
        }
 
@@ -339,7 +337,7 @@ public class LineageCacheStatistics {
        }
 
        public static boolean ifGpuStats() {
-               return (_numHitsGpu.longValue() + _numAsyncEvictGpu.longValue()
+               return (_numHitsGpu.longValue() + _numPrefetchGpu.longValue()
                        + _numSyncEvictGpu.longValue() + 
_numRecycleGpu.longValue()
                        + _numDelGpu.longValue() + _evtimeGpu.longValue()) != 0;
        }
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java 
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 6978507179..01c9682d90 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -639,7 +639,7 @@ public class Statistics
                                sb.append("LinCache hits (Mem/FS/Del): \t" + 
LineageCacheStatistics.displayHits() + ".\n");
                                sb.append("LinCache MultiLevel (Ins/SB/Fn):" + 
LineageCacheStatistics.displayMultiLevelHits() + ".\n");
                                if (LineageCacheStatistics.ifGpuStats()) {
-                                       sb.append("LinCache GPU 
(Hit/Async/Sync): \t" + LineageCacheStatistics.displayGpuStats() + ".\n");
+                                       sb.append("LinCache GPU (Hit/PF): \t" + 
LineageCacheStatistics.displayGpuStats() + ".\n");
                                        sb.append("LinCache GPU (Recyc/Del): 
\t" + LineageCacheStatistics.displayGpuPointerStats() + ".\n");
                                        sb.append("LinCache GPU evict time: \t" 
+ LineageCacheStatistics.displayGpuEvictTime() + " sec.\n");
                                }
diff --git a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java 
b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
index ae21ea0672..8c110a0b92 100644
--- a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
@@ -131,7 +131,7 @@ public class SparkStatistics {
                                                
parallelizeTime.longValue()*1e-9,
                                                broadcastTime.longValue()*1e-9,
                                                collectTime.longValue()*1e-9));
-               sb.append("Spark async. count (pf,bc,op): \t" +
+               sb.append("Async. OP count (pf,bc,op): \t" +
                                String.format("%d/%d/%d.\n", 
getAsyncPrefetchCount(), getAsyncBroadcastCount(), getAsyncSparkOpCount()));
                return sb.toString();
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
 
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
index eeb3f6d5e7..51a166c10d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
@@ -96,13 +96,13 @@ public class MaxParallelizeOrderTest extends 
AutomatedTestBase {
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
 
-                       OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
+                       OptimizerUtils.ASYNC_PREFETCH = true;
                        OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
                        if (testname.equalsIgnoreCase(TEST_NAME+"4"))
                                OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE 
= false;
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> R_mp = 
readDMLScalarFromOutputDir("R");
-                       OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+                       OptimizerUtils.ASYNC_PREFETCH = false;
                        OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
                        OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index 886a850d22..46ab3444df 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -103,9 +103,9 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
 
-                       OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
+                       OptimizerUtils.ASYNC_PREFETCH = true;
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-                       OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+                       OptimizerUtils.ASYNC_PREFETCH = false;
                        OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
                        HashMap<MatrixValue.CellIndex, Double> R_pf = 
readDMLScalarFromOutputDir("R");
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
index 7666a30184..7d700cc3b6 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
@@ -136,12 +136,12 @@ public class ReuseAsyncOpTest extends AutomatedTestBase {
        private void enableAsync() {
                OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
                OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
-               OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
+               OptimizerUtils.ASYNC_PREFETCH = true;
        }
 
        private void disableAsync() {
                OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
                OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
-               OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+               OptimizerUtils.ASYNC_PREFETCH = false;
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/GPUFullReuseTest.java 
b/src/test/java/org/apache/sysds/test/functions/lineage/GPUFullReuseTest.java
index 1a0665c187..47bfd8c4e8 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/lineage/GPUFullReuseTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/lineage/GPUFullReuseTest.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.lineage.Lineage;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
@@ -43,21 +44,21 @@ public class GPUFullReuseTest extends AutomatedTestBase{
        protected static final String TEST_NAME = "LineageReuseGPU";
        protected static final int TEST_VARIANTS = 4;
        protected String TEST_CLASS_DIR = TEST_DIR + 
GPUFullReuseTest.class.getSimpleName() + "/";
-       
+
        @BeforeClass
        public static void checkGPU() {
                // Skip all the tests if no GPU is available
                // FIXME: Fails to skip if gpu available but no libraries
                Assume.assumeTrue(TestUtils.isGPUAvailable() == 
cudaError.cudaSuccess);
        }
-       
+
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
                for( int i=1; i<=TEST_VARIANTS; i++ )
                        addTestConfiguration(TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i));
        }
-       
+
        @Test
        public void ReuseAggBin() {           //reuse AggregateBinary and sum
                testLineageTraceExec(TEST_NAME+"1");
@@ -90,9 +91,10 @@ public class GPUFullReuseTest extends AutomatedTestBase{
                proArgs.add(output("R"));
                programArgs = proArgs.toArray(new String[proArgs.size()]);
                fullDMLScriptName = getScript();
-               
+
                Lineage.resetInternalState();
                //run the test
+               OptimizerUtils.ASYNC_PREFETCH = true;
                runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                HashMap<MatrixValue.CellIndex, Double> R_orig = 
readDMLMatrixFromOutputDir("R");
 
@@ -104,14 +106,15 @@ public class GPUFullReuseTest extends AutomatedTestBase{
                proArgs.add(output("R"));
                programArgs = proArgs.toArray(new String[proArgs.size()]);
                fullDMLScriptName = getScript();
-               
+
                Lineage.resetInternalState();
                //run the test
                runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+               OptimizerUtils.ASYNC_PREFETCH = false;
                AutomatedTestBase.TEST_GPU = false;
                HashMap<MatrixValue.CellIndex, Double> R_reused = 
readDMLMatrixFromOutputDir("R");
 
-               //compare results 
+               //compare results
                TestUtils.compareMatrices(R_orig, R_reused, 1e-6, "Origin", 
"Reused");
 
                if( testname.endsWith("3") ) { //function reuse

Reply via email to