This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 5182796632 [SYSTEMDS-3474] Lineage-based reuse of future-based 
instructions
5182796632 is described below

commit 5182796632bfb9173f5d2e7b2e7d20e434270bda
Author: Arnab Phani <[email protected]>
AuthorDate: Tue Dec 6 12:36:09 2022 +0100

    [SYSTEMDS-3474] Lineage-based reuse of future-based instructions
    
    This patch enables caching and reuse of future-based Spark
    actions.
    
    Closes #1747
---
 .../controlprogram/context/ExecutionContext.java   |  7 +-
 .../controlprogram/context/MatrixObjectFuture.java |  9 ++-
 .../instructions/cp/PrefetchCPInstruction.java     |  4 +-
 .../instructions/cp/TriggerPrefetchTask.java       |  8 +-
 .../spark/AggregateUnarySPInstruction.java         |  5 +-
 .../instructions/spark/CpmmSPInstruction.java      |  5 +-
 .../instructions/spark/MapmmSPInstruction.java     |  5 +-
 .../instructions/spark/TsmmSPInstruction.java      |  5 +-
 .../apache/sysds/runtime/lineage/LineageCache.java | 53 ++++--------
 .../sysds/runtime/lineage/LineageCacheConfig.java  |  4 +
 .../functions/async/LineageReuseSparkTest.java     | 44 +---------
 ...geReuseSparkTest.java => ReuseAsyncOpTest.java} | 93 +++++++++-------------
 .../{LineageReuseSpark2.dml => ReuseAsyncOp1.dml}  |  0
 .../{LineageReuseSpark2.dml => ReuseAsyncOp2.dml}  | 41 ++++++----
 14 files changed, 127 insertions(+), 156 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 0fa1340569..5e3e90f469 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -602,16 +602,21 @@ public class ExecutionContext {
                mo.release();
        }
 
-       public void setMatrixOutput(String varName, Future<MatrixBlock> fmb) {
+       public void setMatrixOutputAndLineage(String varName, 
Future<MatrixBlock> fmb, LineageItem li) {
                if (isAutoCreateVars() && !containsVariable(varName)) {
                        MatrixObject fmo = new 
MatrixObjectFuture(Types.ValueType.FP64,
                                OptimizerUtils.getUniqueTempFileName(), fmb);
                }
                MatrixObject mo = getMatrixObject(varName);
                MatrixObjectFuture fmo = new MatrixObjectFuture(mo, fmb);
+               fmo.setCacheLineage(li);
                setVariable(varName, fmo);
        }
 
+       public void setMatrixOutput(String varName, Future<MatrixBlock> fmb) {
+               setMatrixOutputAndLineage(varName, fmb, null);
+       }
+
        public void setMatrixOutput(String varName, MatrixBlock outputData, 
UpdateType flag) {
                if( isAutoCreateVars() && !containsVariable(varName) )
                        setVariable(varName, createMatrixObject(outputData));
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
index 3cbc7eff09..3c5581937b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.controlprogram.context;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.lineage.LineageCache;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
 import java.util.concurrent.Future;
@@ -59,8 +60,14 @@ public class MatrixObjectFuture extends MatrixObject
                                throw new DMLRuntimeException("MatrixObject not 
available to read.");
                        if(_data != null)
                                throw new DMLRuntimeException("_data must be 
null for future matrix object/block.");
+                       MatrixBlock out = null;
                        acquire(false, false);
-                       return _futureData.get();
+                       long t1 = System.nanoTime();
+                       out = _futureData.get();
+                       if (hasValidLineage())
+                               LineageCache.putValueAsyncOp(getCacheLineage(), 
this, out, t1);
+                               // FIXME: start time should indicate the actual 
start of the execution
+                       return out;
                }
 
                catch(Exception e) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index 192e165391..fa0d1c0e83 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -46,7 +46,7 @@ public class PrefetchCPInstruction extends UnaryCPInstruction 
{
        public void processInstruction(ExecutionContext ec) {
                // TODO: handle non-matrix objects
                ec.setVariable(output.getName(), ec.getMatrixObject(input1));
-               LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? 
this.getLineageItem(ec).getValue() : null;
+               LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? 
getLineageItem(ec).getValue() : null;
 
                // Note, a Prefetch instruction doesn't guarantee an 
asynchronous execution.
                // If the next instruction which takes this output as an input 
comes before
@@ -54,6 +54,8 @@ public class PrefetchCPInstruction extends UnaryCPInstruction 
{
                // In that case this Prefetch instruction will act like a NOOP. 
                if (CommonThreadPool.triggerRemoteOPsPool == null)
                        CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
+               // Saving the lineage item inside the matrix object will 
replace the pre-attached
+               // lineage item (e.g. mapmm). Hence, passing separately.
                CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerPrefetchTask(ec.getMatrixObject(output), li));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
index b7c69d01f5..78857c5a17 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
@@ -24,6 +24,7 @@ import 
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
 import org.apache.sysds.runtime.lineage.LineageCache;
 import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.utils.stats.SparkStatistics;
 
 public class TriggerPrefetchTask implements Runnable {
@@ -43,6 +44,7 @@ public class TriggerPrefetchTask implements Runnable {
        @Override
        public void run() {
                boolean prefetched = false;
+               MatrixBlock mb = null;
                long t1 = System.nanoTime();
                synchronized (_prefetchMO) {
                        // Having this check inside the critical section
@@ -50,14 +52,14 @@ public class TriggerPrefetchTask implements Runnable {
                        if (_prefetchMO.isPendingRDDOps() || 
_prefetchMO.isFederated()) {
                                // TODO: Add robust runtime constraints for 
federated prefetch
                                // Execute and bring the result to local
-                               _prefetchMO.acquireReadAndRelease();
+                               mb = _prefetchMO.acquireReadAndRelease();
                                prefetched = true;
                        }
                }
 
                // Save the collected intermediate in the lineage cache
-               if (_inputLi != null)
-                       LineageCache.putValueAsyncOp(_inputLi, _prefetchMO, 
prefetched, t1);
+               if (_inputLi != null && mb != null)
+                       LineageCache.putValueAsyncOp(_inputLi, _prefetchMO, mb, 
t1);
 
                if (DMLScript.STATISTICS && prefetched) {
                        if (_prefetchMO.isFederated())
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 48d41fd602..50816aefe4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -39,6 +39,8 @@ import 
org.apache.sysds.runtime.instructions.spark.functions.AggregateDropCorrec
 import 
org.apache.sysds.runtime.instructions.spark.functions.FilterDiagMatrixBlocksFunction;
 import 
org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
 import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
@@ -117,7 +119,8 @@ public class AggregateUnarySPInstruction extends 
UnarySPInstruction {
                                                
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
                                        RDDAggregateTask task = new 
RDDAggregateTask(_optr, _aop, in, mc);
                                        Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
-                                       sec.setMatrixOutput(output.getName(), 
future_out);
+                                       LineageItem li = 
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : 
null;
+                                       
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
                                }
                                catch(Exception ex) {
                                        throw new DMLRuntimeException(ex);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 653596806d..79832eabe2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -37,6 +37,8 @@ import 
org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlock
 import org.apache.sysds.runtime.instructions.spark.functions.ReorgMapFunction;
 import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
@@ -113,7 +115,8 @@ public class CpmmSPInstruction extends 
AggregateBinarySPInstruction {
                                                
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
                                        CpmmMatrixVectorTask task = new 
CpmmMatrixVectorTask(in1, in2);
                                        Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
-                                       sec.setMatrixOutput(output.getName(), 
future_out);
+                                       LineageItem li = 
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : 
null;
+                                       
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
                                }
                                catch(Exception ex) {
                                        throw new DMLRuntimeException(ex);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index 29f28b604e..b0285b1bba 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -50,6 +50,8 @@ import 
org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
 import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
 import 
org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
 import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
@@ -146,7 +148,8 @@ public class MapmmSPInstruction extends 
AggregateBinarySPInstruction {
                                                
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
                                        RDDMapmmTask task = new  
RDDMapmmTask(in1, in2, type);
                                        Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
-                                       sec.setMatrixOutput(output.getName(), 
future_out);
+                                       LineageItem li = 
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : 
null;
+                                       
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
                                }
                                catch(Exception ex) { throw new 
DMLRuntimeException(ex); }
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
index 17cef61158..acba784bf2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
@@ -31,6 +31,8 @@ import 
org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -74,7 +76,8 @@ public class TsmmSPInstruction extends UnarySPInstruction {
                                        CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
                                TsmmTask task = new TsmmTask(in, _type);
                                Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
-                               sec.setMatrixOutput(output.getName(), 
future_out);
+                               LineageItem li = 
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() : 
null;
+                               sec.setMatrixOutputAndLineage(output.getName(), 
future_out, li);
                        }
                        catch(Exception ex) {
                                throw new DMLRuntimeException(ex);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 8e8d962199..87ddd7b8da 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.lineage;
 
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.tuple.MutablePair;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.api.DMLScript;
@@ -575,16 +576,13 @@ public class LineageCache
                                        continue;
                                }
 
-                               if (data instanceof MatrixObjectFuture) {
+                               if (data instanceof MatrixObjectFuture || inst 
instanceof PrefetchCPInstruction) {
                                        // We don't want to call get() on the 
future immediately after the execution
+                                       // For the async. instructions, caching 
is handled separately by the tasks
                                        removePlaceholder(item);
                                        continue;
                                }
 
-                               if (inst instanceof PrefetchCPInstruction || 
inst instanceof BroadcastCPInstruction)
-                                       // For the async. instructions, caching 
is handled separately by the tasks
-                                       continue;
-
                                if (data instanceof MatrixObject && 
((MatrixObject) data).hasRDDHandle()) {
                                        // Avoid triggering pre-matured Spark 
instruction chains
                                        removePlaceholder(item);
@@ -643,49 +641,28 @@ public class LineageCache
                }
        }
 
-       public static void putValueAsyncOp(LineageItem instLI, Data data, 
boolean prefetched, long starttime)
+       // This method is called from inside the asynchronous operators and 
directly puts the output of
+       // an asynchronous instruction into the lineage cache. As the 
consumer, a different operator,
+       // materializes the intermediate, we skip the placeholder placing logic.
+       public static void putValueAsyncOp(LineageItem instLI, Data data, 
MatrixBlock mb, long starttime)
        {
                if (ReuseCacheType.isNone())
                        return;
-               if (!prefetched) //prefetching was not successful
+               if 
(!ArrayUtils.contains(LineageCacheConfig.getReusableOpcodes(), 
instLI.getOpcode()))
+                       return;
+               if(!(data instanceof MatrixObject) && !(data instanceof 
ScalarObject)) {
                        return;
+               }
 
                synchronized( _cache )
                {
-                       if (!probe(instLI))
-                               return;
-
                        long computetime = System.nanoTime() - starttime;
-                       LineageCacheEntry centry = _cache.get(instLI);
-                       if(!(data instanceof MatrixObject) && !(data instanceof 
ScalarObject)) {
-                               // Reusable instructions can return a frame 
(rightIndex). Remove placeholders.
-                               removePlaceholder(instLI);
-                               return;
-                       }
+                       // Make space, place data and manage queue
+                       putIntern(instLI, DataType.MATRIX, mb, null, 
computetime);
 
-                       MatrixBlock mb = (data instanceof MatrixObject) ?
-                               ((MatrixObject)data).acquireReadAndRelease() : 
null;
-                       long size = mb != null ? mb.getInMemorySize() : 
((ScalarObject)data).getSize();
-
-                       // remove the placeholder if the entry is bigger than 
the cache.
-                       if (size > LineageCacheEviction.getCacheLimit()) {
-                               removePlaceholder(instLI);
-                               return;
-                       }
-
-                       // place the data
-                       if (data instanceof MatrixObject)
-                               centry.setValue(mb, computetime);
-                       else if (data instanceof ScalarObject)
-                               centry.setValue((ScalarObject)data, 
computetime);
-
-                       if (DMLScript.STATISTICS && 
LineageCacheEviction._removelist.containsKey(centry._key)) {
+                       if (DMLScript.STATISTICS && 
LineageCacheEviction._removelist.containsKey(instLI))
                                // Add to missed compute time
-                               
LineageCacheStatistics.incrementMissedComputeTime(centry._computeTime);
-                       }
-
-                       //maintain order for eviction
-                       LineageCacheEviction.addEntry(centry);
+                               
LineageCacheStatistics.incrementMissedComputeTime(computetime);
                }
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 72ea3835a2..fe32f364e5 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -197,6 +197,10 @@ public class LineageCacheConfig
        public static void setReusableOpcodes(String... ops) {
                REUSE_OPCODES = ops;
        }
+
+       public static String[] getReusableOpcodes() {
+               return REUSE_OPCODES;
+       }
        
        public static void resetReusableOpcodes() {
                REUSE_OPCODES = OPCODES;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
 
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
index 57d7892b3b..5b49bb82fa 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
@@ -62,11 +62,6 @@ public class LineageReuseSparkTest extends AutomatedTestBase 
{
                runTest(TEST_NAME+"1", ExecMode.SPARK, 1);
        }
 
-       @Test
-       public void testReusePrefetch() {
-               runTest(TEST_NAME+"2", ExecMode.HYBRID, 2);
-       }
-
        public void runTest(String testname, ExecMode execMode, int testId) {
                boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                boolean old_sum_product = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -91,18 +86,10 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        programArgs = proArgs.toArray(new 
String[proArgs.size()]);
 
                        Lineage.resetInternalState();
-                       if (testId == 2) enablePrefetch();
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-                       disablePrefetch();
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
-                       long numTsmm = 0;
-                       long numMapmm = 0;
-                       if (testId == 1) {
-                               numTsmm = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
-                               numMapmm = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
-                       }
-                       long numPrefetch = 0;
-                       if (testId == 2) numPrefetch = 
Statistics.getCPHeavyHitterCount("prefetch");
+                       long numTsmm = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
+                       long numMapmm = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
 
                        proArgs.clear();
                        proArgs.add("-explain");
@@ -114,18 +101,10 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        programArgs = proArgs.toArray(new 
String[proArgs.size()]);
 
                        Lineage.resetInternalState();
-                       if (testId == 2) enablePrefetch();
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-                       disablePrefetch();
                        HashMap<MatrixValue.CellIndex, Double> R_reused = 
readDMLScalarFromOutputDir("R");
-                       long numTsmm_r = 0;
-                       long numMapmm_r = 0;
-                       if (testId == 1) {
-                               numTsmm_r = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
-                               numMapmm_r = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
-                       }
-                       long numPrefetch_r = 0;
-                       if (testId == 2) numPrefetch_r = 
Statistics.getCPHeavyHitterCount("prefetch");
+                       long numTsmm_r = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
+                       long numMapmm_r = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
 
                        //compare matrices
                        boolean matchVal = TestUtils.compareMatrices(R, 
R_reused, 1e-6, "Origin", "withPrefetch");
@@ -135,9 +114,6 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                                Assert.assertTrue("Violated sp_tsmm reuse 
count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
                                Assert.assertTrue("Violated sp_mapmm reuse 
count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
                        }
-                       if (testId == 2)
-                               Assert.assertTrue("Violated prefetch reuse 
count: " + numPrefetch_r + " < " + numPrefetch, numPrefetch_r<numPrefetch);
-
                } finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
old_simplification;
                        OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = 
old_sum_product;
@@ -147,16 +123,4 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        Recompiler.reinitRecompiler();
                }
        }
-
-       private void enablePrefetch() {
-               OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
-               OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
-               OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
-       }
-
-       private void disablePrefetch() {
-               OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
-               OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
-               OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
-       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
 b/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
similarity index 67%
copy from 
src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
copy to 
src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
index 57d7892b3b..7666a30184 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
@@ -19,30 +19,29 @@
 
 package org.apache.sysds.test.functions.async;
 
-       import java.util.ArrayList;
-       import java.util.HashMap;
-       import java.util.List;
-
-       import org.apache.sysds.common.Types.ExecMode;
-       import org.apache.sysds.hops.OptimizerUtils;
-       import org.apache.sysds.hops.recompile.Recompiler;
-       import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
-       import org.apache.sysds.runtime.lineage.Lineage;
-       import org.apache.sysds.runtime.lineage.LineageCacheConfig;
-       import org.apache.sysds.runtime.matrix.data.MatrixValue;
-       import org.apache.sysds.test.AutomatedTestBase;
-       import org.apache.sysds.test.TestConfiguration;
-       import org.apache.sysds.test.TestUtils;
-       import org.apache.sysds.utils.Statistics;
-       import org.junit.Assert;
-       import org.junit.Test;
-
-public class LineageReuseSparkTest extends AutomatedTestBase {
-
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.recompile.Recompiler;
+import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ReuseAsyncOpTest extends AutomatedTestBase {
        protected static final String TEST_DIR = "functions/async/";
-       protected static final String TEST_NAME = "LineageReuseSpark";
+       protected static final String TEST_NAME = "ReuseAsyncOp";
        protected static final int TEST_VARIANTS = 2;
-       protected static String TEST_CLASS_DIR = TEST_DIR + 
LineageReuseSparkTest.class.getSimpleName() + "/";
+       protected static String TEST_CLASS_DIR = TEST_DIR + 
ReuseAsyncOpTest.class.getSimpleName() + "/";
 
        @Override
        public void setUp() {
@@ -52,18 +51,14 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
        }
 
        @Test
-       public void testlmdsHB() {
+       public void testReusePrefetch() {
+               // Reuse prefetch results
                runTest(TEST_NAME+"1", ExecMode.HYBRID, 1);
        }
 
        @Test
-       public void testlmdsSP() {
-               // Only reuse the actions
-               runTest(TEST_NAME+"1", ExecMode.SPARK, 1);
-       }
-
-       @Test
-       public void testReusePrefetch() {
+       public void testlmds() {
+               // Reuse future-based tsmm and mapmm
                runTest(TEST_NAME+"2", ExecMode.HYBRID, 2);
        }
 
@@ -91,18 +86,13 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        programArgs = proArgs.toArray(new 
String[proArgs.size()]);
 
                        Lineage.resetInternalState();
-                       if (testId == 2) enablePrefetch();
+                       enableAsync(); //enable max_reuse and prefetch
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-                       disablePrefetch();
+                       disableAsync();
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
-                       long numTsmm = 0;
-                       long numMapmm = 0;
-                       if (testId == 1) {
-                               numTsmm = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
-                               numMapmm = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
-                       }
-                       long numPrefetch = 0;
-                       if (testId == 2) numPrefetch = 
Statistics.getCPHeavyHitterCount("prefetch");
+                       long numTsmm = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
+                       long numMapmm = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
+                       long numPrefetch = 
Statistics.getCPHeavyHitterCount("prefetch");
 
                        proArgs.clear();
                        proArgs.add("-explain");
@@ -114,28 +104,23 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        programArgs = proArgs.toArray(new 
String[proArgs.size()]);
 
                        Lineage.resetInternalState();
-                       if (testId == 2) enablePrefetch();
+                       enableAsync(); //enable max_reuse and prefetch
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-                       disablePrefetch();
+                       disableAsync();
                        HashMap<MatrixValue.CellIndex, Double> R_reused = 
readDMLScalarFromOutputDir("R");
-                       long numTsmm_r = 0;
-                       long numMapmm_r = 0;
-                       if (testId == 1) {
-                               numTsmm_r = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
-                               numMapmm_r = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
-                       }
-                       long numPrefetch_r = 0;
-                       if (testId == 2) numPrefetch_r = 
Statistics.getCPHeavyHitterCount("prefetch");
+                       long numTsmm_r = 
Statistics.getCPHeavyHitterCount("sp_tsmm");
+                       long numMapmm_r = 
Statistics.getCPHeavyHitterCount("sp_mapmm");
+                       long numPrefetch_r = 
Statistics.getCPHeavyHitterCount("prefetch");
 
                        //compare matrices
                        boolean matchVal = TestUtils.compareMatrices(R, 
R_reused, 1e-6, "Origin", "withPrefetch");
                        if (!matchVal)
                                System.out.println("Value w/o reuse "+R+" w/ 
reuse "+R_reused);
-                       if (testId == 1) {
+                       if (testId == 2) {
                                Assert.assertTrue("Violated sp_tsmm reuse 
count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
                                Assert.assertTrue("Violated sp_mapmm reuse 
count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
                        }
-                       if (testId == 2)
+                       if (testId == 1)
                                Assert.assertTrue("Violated prefetch reuse 
count: " + numPrefetch_r + " < " + numPrefetch, numPrefetch_r<numPrefetch);
 
                } finally {
@@ -148,13 +133,13 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                }
        }
 
-       private void enablePrefetch() {
+       private void enableAsync() {
                OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
                OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
                OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
        }
 
-       private void disablePrefetch() {
+       private void disableAsync() {
                OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
                OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
                OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
diff --git a/src/test/scripts/functions/async/LineageReuseSpark2.dml 
b/src/test/scripts/functions/async/ReuseAsyncOp1.dml
similarity index 100%
copy from src/test/scripts/functions/async/LineageReuseSpark2.dml
copy to src/test/scripts/functions/async/ReuseAsyncOp1.dml
diff --git a/src/test/scripts/functions/async/LineageReuseSpark2.dml 
b/src/test/scripts/functions/async/ReuseAsyncOp2.dml
similarity index 59%
rename from src/test/scripts/functions/async/LineageReuseSpark2.dml
rename to src/test/scripts/functions/async/ReuseAsyncOp2.dml
index 63792332b4..f4675f69b9 100644
--- a/src/test/scripts/functions/async/LineageReuseSpark2.dml
+++ b/src/test/scripts/functions/async/ReuseAsyncOp2.dml
@@ -18,23 +18,36 @@
 # under the License.
 #
 #-------------------------------------------------------------
-X = rand(rows=10000, cols=200, seed=42); #sp_rand
-v = rand(rows=200, cols=1, seed=42); #cp_rand
 
-# Spark transformation operations 
-for (i in 1:10) {
-  while(FALSE){}
-  sp1 = X + ceil(X);
-  sp2 = sp1 %*% v; #output fits in local
-  # Place a prefetch after mapmm and reuse
+SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y, Double lamda, 
Integer N) return (Matrix[double] beta)
+{
+  # Reuse sp_tsmm and sp_mapmm if not future-based
+  A = (t(X) %*% X) + diag(matrix(lamda, rows=N, cols=1));
+  b = t(X) %*% y;
+  beta = solve(A, b);
+}
+
+no_lamda = 10;
 
-  # CP instructions
-  v2 = ((v + v) * 1 - v) / (1+1);
-  v2 = ((v + v) * 2 - v) / (2+1);
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+lim = 0.1;
 
-  # CP binary triggers the DAG of SP operations
-  cp = sp2 + sum(v2);
-  R = sum(cp);
+X = rand(rows=10000, cols=200, seed=42);
+y = rand(rows=10000, cols=1, seed=43);
+N = ncol(X);
+R = matrix(0, rows=N, cols=no_lamda+2);
+i = 1;
+
+while (lamda < lim)
+{
+  beta = SimlinRegDS(X, y, lamda, N);
+  #beta = lmDS(X=X, y=y, reg=lamda);
+  R[,i] = beta;
+  lamda = lamda + stp;
+  i = i + 1;
 }
 
+R = sum(R);
 write(R, $1, format="text");
+

Reply via email to