This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 5fd3c65b84 [SYSTEMDS-3599] Cleanup child RDDs, broadcasts from driver 
and executors
5fd3c65b84 is described below

commit 5fd3c65b84438ddbe73137f6f213d2b03e1c4b13
Author: Arnab Phani <[email protected]>
AuthorDate: Mon Jul 17 17:37:47 2023 +0200

    [SYSTEMDS-3599] Cleanup child RDDs, broadcasts from driver and executors
    
    This patch adds methods to clean up the child RDDs of lineage-cached RDDs.
    On the first hit, we mark the RDD but let it and its child RDDs be cleaned
    up by the rmVar logic. On the second hit, we call persist while putting that
    RDD in the cache. On a later hit, if the RDD is already persisted, we clean
    up its child RDDs, including checkpointed RDDs and broadcast variables. If
    it is still not persisted after a few local reuses, we asynchronously move
    it to Spark by triggering a job; a subsequent reuse then cleans up the
    child RDDs.
    
    Closes #1866
---
 .../context/SparkExecutionContext.java             |  27 ++++--
 .../spark/CheckpointSPInstruction.java             |  19 ++++
 .../instructions/spark/data/LineageObject.java     |   4 +
 .../apache/sysds/runtime/lineage/LineageCache.java |  78 +++++++++++-----
 .../sysds/runtime/lineage/LineageCacheEntry.java   |   7 ++
 .../runtime/lineage/LineageCacheEviction.java      |   4 +-
 .../runtime/lineage/LineageSparkCacheEviction.java |  77 ++++++++++++++++
 .../functions/async/LineageReuseSparkTest.java     |  23 ++++-
 .../scripts/functions/async/LineageReuseSpark4.dml |  36 ++++++--
 .../scripts/functions/async/LineageReuseSpark5.dml |  56 ++++++++++++
 .../scripts/functions/async/LineageReuseSpark6.dml | 101 +++++++++++++++++++++
 11 files changed, 385 insertions(+), 47 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index fca6fc0458..aa58f979c9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1503,7 +1503,6 @@ public class SparkExecutionContext extends 
ExecutionContext
                }
        }
 
-       @SuppressWarnings({ "rawtypes", "unchecked" })
        private void rCleanupLineageObject(LineageObject lob)
                throws IOException
        {
@@ -1520,6 +1519,19 @@ public class SparkExecutionContext extends 
ExecutionContext
                if (lob.isInLineageCache())
                        return;
 
+               //cleanup current lineage object (from driver/executors)
+               //incl deferred hdfs file removal (only if metadata set by 
cleanup call)
+               cleanupSingleLineageObject(lob);
+
+               //recursively process lineage children
+               for( LineageObject c : lob.getLineageChilds() ){
+                       c.decrementNumReferences();
+                       rCleanupLineageObject(c);
+               }
+       }
+
+       @SuppressWarnings({ "rawtypes", "unchecked" })
+       public static void cleanupSingleLineageObject(LineageObject lob) {
                //cleanup current lineage object (from driver/executors)
                //incl deferred hdfs file removal (only if metadata set by 
cleanup call)
                if( lob instanceof RDDObject ) {
@@ -1527,7 +1539,12 @@ public class SparkExecutionContext extends 
ExecutionContext
                        int rddID = rdd.getRDD().id();
                        cleanupRDDVariable(rdd.getRDD());
                        if( rdd.getHDFSFilename()!=null ) { //deferred file 
removal
-                               
HDFSTool.deleteFileWithMTDIfExistOnHDFS(rdd.getHDFSFilename());
+                               try {
+                                       
HDFSTool.deleteFileWithMTDIfExistOnHDFS(rdd.getHDFSFilename());
+                               }
+                               catch(IOException e) {
+                                       throw new DMLRuntimeException(e);
+                               }
                        }
                        if( rdd.isParallelizedRDD() )
                                _parRDDs.deregisterRDD(rddID);
@@ -1548,12 +1565,6 @@ public class SparkExecutionContext extends 
ExecutionContext
                        }
                        CacheableData.addBroadcastSize(-bob.getSize());
                }
-
-               //recursively process lineage children
-               for( LineageObject c : lob.getLineageChilds() ){
-                       c.decrementNumReferences();
-                       rCleanupLineageObject(c);
-               }
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
index 16d8c9d228..2be663bdbc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CheckpointSPInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.spark;
 
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.storage.StorageLevel;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.hops.OptimizerUtils;
@@ -41,6 +42,7 @@ import 
org.apache.sysds.runtime.instructions.spark.functions.CopyFrameBlockFunct
 import 
org.apache.sysds.runtime.instructions.spark.functions.CreateSparseBlockFunction;
 import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -99,6 +101,7 @@ public class CheckpointSPInstruction extends 
UnarySPInstruction {
                        //add a dummy entry to the input, which will be 
immediately overwritten by the null output.
                        sec.setVariable( input1.getName(), new 
BooleanObject(false));
                        sec.setVariable( output.getName(), new 
BooleanObject(false));
+                       replaceLineage(ec);
                        return;
                }
 
@@ -107,6 +110,7 @@ public class CheckpointSPInstruction extends 
UnarySPInstruction {
                        // Do nothing if the RDD is already checkpointed
                        sec.setVariable(output.getName(), 
sec.getCacheableData(input1.getName()));
                        Statistics.decrementNoOfExecutedSPInst();
+                       replaceLineage(ec);
                        return;
                }
                //-------
@@ -121,6 +125,7 @@ public class CheckpointSPInstruction extends 
UnarySPInstruction {
                        //available in memory
                        sec.setVariable(output.getName(), obj);
                        Statistics.decrementNoOfExecutedSPInst();
+                       replaceLineage(ec);
                        return;
                }
                
@@ -187,6 +192,7 @@ public class CheckpointSPInstruction extends 
UnarySPInstruction {
                }
                else {
                        out = in; //pass-through
+                       replaceLineage(ec);
                }
                
                // Step 3: In-place update of input matrix/frame rdd handle and 
set as output
@@ -207,5 +213,18 @@ public class CheckpointSPInstruction extends 
UnarySPInstruction {
                        cd.setRDDHandle(outro);
                }
                sec.setVariable( output.getName(), cd);
+               //TODO: remove lineage tracing of chkpoint to allow
+               //  reuse across loops and basic blocks
+               //replaceLineage(ec);
+       }
+
+       private void replaceLineage(ExecutionContext ec) {
+               // Copy the lineage trace of the input to the output
+               // to prevent unnecessary chkpoint lineage entry, which wrongly
+               // reduces reuse opportunities for nested loop bodies.
+               if (DMLScript.LINEAGE) {
+                       LineageItem inputLi = 
ec.getLineageItem(input1.getName());
+                       ec.getLineage().set(output.getName(), inputLi);
+               }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
index 2f9d1bb0bd..146a619a7e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/data/LineageObject.java
@@ -85,4 +85,8 @@ public abstract class LineageObject
                lob.incrementNumReferences();
                _childs.add( lob );
        }
+
+       public void removeAllChild() {
+               _childs.clear();
+       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 1296e2ee22..557a95d3b8 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -157,15 +157,20 @@ public class LineageCache
                                                //Reuse the cached RDD (local 
or persisted at the executors)
                                                switch(e.getCacheStatus()) {
                                                        case TOPERSISTRDD:
-                                                               //Mark for 
caching on the second hit
-                                                               boolean 
persisted = persistRDD(inst, e, ec);
-                                                               //Return if not 
already persisted and not a shuffle operations
-                                                               if (!persisted 
&& !LineageCacheConfig.isShuffleOp(inst.getOpcode()))
-                                                                       return 
false;
-                                                               //Else, fall 
through to reuse (local or distributed)
+                                                               //Change status 
to PERSISTEDRDD on the second hit
+                                                               //putValueRDD 
method will save the RDD and call persist
+                                                               
e.setCacheStatus(LineageCacheStatus.PERSISTEDRDD);
+                                                               //Cannot reuse 
rdd as already garbage collected
+                                                               return false;
                                                        case PERSISTEDRDD:
                                                                //Reuse the 
persisted intermediate at the executors
                                                                
((SparkExecutionContext) ec).setRDDHandleForVariable(outName, rdd);
+                                                               //Safely 
cleanup the child RDDs if this RDD is persisted already
+                                                               //If reused more than 3 times and still not persisted, move to Spark asynchronously
+                                                               if 
(probeRDDDistributed(e))
+                                                                       
LineageSparkCacheEviction.cleanupChildRDDs(e);
+                                                               else
+                                                                       
LineageSparkCacheEviction.moveToSpark(e);
                                                                break;
                                                        default:
                                                                return false;
@@ -295,25 +300,26 @@ public class LineageCache
                                        
LineageGPUCacheEviction.incrementLiveCount(e.getGPUPointer());
                                else if (e.isRDDPersist()) {
                                        //Reuse the cached RDD (local or 
persisted at the executors)
-                                       RDDObject rdd = e.getRDDObject();
                                        switch(e.getCacheStatus()) {
                                                case TOPERSISTRDD:
-                                                       //Mark for caching on 
the second hit
-                                                       long estimatedSize = 
MatrixBlock.estimateSizeInMemory(rdd.getDataCharacteristics());
-                                                       boolean persisted = 
persistRDD(e, estimatedSize);
-                                                       //Return if not already 
persisted and not a shuffle operations
-                                                       if (!persisted && 
!LineageCacheConfig.isShuffleOp(e._origItem.getOpcode()))
-                                                               return false;
-                                                       //Else, fall through to 
reuse (local or distributed)
+                                                       //Cannot reuse rdd as 
already garbage collected
+                                                       //putValue method will 
save the RDD and call persist
+                                                       //while caching the 
original instruction
+                                                       return false;
                                                case PERSISTEDRDD:
                                                        //Reuse the persisted 
intermediate at the executors
+                                                       //Safely cleanup the 
child RDDs if this RDD is persisted already
+                                                       //If reused more than 3 times and still not persisted, move to Spark asynchronously
+                                                       if 
(probeRDDDistributed(e))
+                                                               
LineageSparkCacheEviction.cleanupChildRDDs(e);
+                                                       else
+                                                               
LineageSparkCacheEviction.moveToSpark(e);
                                                        break;
                                                default:
                                                        return false;
                                        }
                                }
                        }
-
                        funcOutputs.forEach((var, val) -> {
                                //cleanup existing data bound to output 
variable name
                                Data exdata = ec.removeVariable(var);
@@ -328,7 +334,6 @@ public class LineageCache
                        if (DMLScript.STATISTICS) //increment saved time
                                
LineageCacheStatistics.incrementSavedComputeTime(savedComputeTime);
                }
-
                return reuse;
        }
        
@@ -492,7 +497,10 @@ public class LineageCache
        private static boolean probeRDDDistributed(LineageItem key) {
                if (!_cache.containsKey(key))
                        return false;
-               LineageCacheEntry e = _cache.get(key);
+               return probeRDDDistributed(_cache.get(key));
+       }
+
+       protected static boolean probeRDDDistributed(LineageCacheEntry e) {
                if (!e.isRDDPersist())
                        return false;
                return 
SparkExecutionContext.isRDDCached(e.getRDDObject().getRDD().id());
@@ -716,8 +724,9 @@ public class LineageCache
                        if (!probe(instLI))
                                return;
                        LineageCacheEntry centry = _cache.get(instLI);
-                       // Put in the cache only the first time
-                       if (centry.getCacheStatus() != LineageCacheStatus.EMPTY)
+                       // Remember the 1st hit and put the RDD in the cache 
the 2nd time
+                       if (centry.getCacheStatus() != LineageCacheStatus.EMPTY 
           //first hit
+                               && centry.getCacheStatus() != 
LineageCacheStatus.PERSISTEDRDD) //second hit
                                return;
                        // Avoid reuse chkpoint, which is unnecessary
                        if (inst.getOpcode().equalsIgnoreCase("chkpoint")) {
@@ -742,10 +751,29 @@ public class LineageCache
                        // Get the RDD handle of the RDD
                        CacheableData<?> cd = 
ec.getCacheableData(((ComputationSPInstruction)inst).output.getName());
                        RDDObject rddObj = cd.getRDDHandle();
-                       // Set the RDD object in the cache and set the status 
to TOPERSISTRDD
-                       rddObj.setLineageCached();
+                       // Save the metadata. Required for estimating cached 
space overhead.
                        
rddObj.setDataCharacteristics(cd.getDataCharacteristics());
-                       centry.setRDDValue(rddObj, computetime);
+                       // Set the RDD object in the cache
+                       switch(centry.getCacheStatus()) {
+                               case EMPTY:  //first hit
+                                       // Do not save the child RDDs (incl. broadcast vars) on the first hit.
+                                       // Let them be garbage collected via 
rmvar. Save them on the second hit
+                                       // by disabling garbage collection on 
this and the child RDDs.
+                                       centry.setRDDValue(rddObj, 
computetime); //rddObj will be garbage collected
+                                       break;
+                               case PERSISTEDRDD:  //second hit
+                                       // Replace the old RDD (GCed) with the 
new one
+                                       centry.setRDDValue(rddObj);
+                                       // Set the correct status to indicate 
the RDD is marked to be persisted
+                                       
centry.setCacheStatus(LineageCacheStatus.PERSISTEDRDD);
+                                       // Call persist. Next collect will 
materialize this intermediate in Spark
+                                       persistRDD(inst, centry, ec);
+                                       // Mark lineage cached to prevent this 
and child RDDs from cleanup by rmvar
+                                       
centry.getRDDObject().setLineageCached();
+                                       break;
+                               default:
+                                       throw new 
DMLRuntimeException("Execution should not reach here: "+centry._key);
+                       }
                }
        }
 
@@ -1017,6 +1045,8 @@ public class LineageCache
                        return;
                // Move the value from the cache entry with key probeItem to
                // the placeholder entry with key item.
+               // Entries with RDDs are cached twice. First hit is GCed,
+               // Second hit saves the child RDDs
                if (LineageCache.probe(probeItem)) {
                        LineageCacheEntry oe = getIntern(probeItem);
                        LineageCacheEntry e = _cache.get(item);
@@ -1083,7 +1113,7 @@ public class LineageCache
        private static boolean persistRDD(Instruction inst, LineageCacheEntry 
centry, ExecutionContext ec) {
                // If already persisted, change the status and return true.
                // Else, persist, change cache status and return false.
-               if (probeRDDDistributed(centry._key)) {
+               if (probeRDDDistributed(centry)) {
                        // Update status to indicate persisted in the executors
                        centry.setCacheStatus(LineageCacheStatus.PERSISTEDRDD);
                        return true;
@@ -1103,7 +1133,7 @@ public class LineageCache
        private static boolean persistRDD(LineageCacheEntry centry, long 
estimatedSize) {
                // If already persisted, change the status and return true.
                // Else, persist, change cache status and return false.
-               if (probeRDDDistributed(centry._key)) {
+               if (probeRDDDistributed(centry)) {
                        // Update status to indicate persisted in the executors
                        centry.setCacheStatus(LineageCacheStatus.PERSISTEDRDD);
                        return true;
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index 4f3a3354ab..cc66b321a0 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -240,6 +240,13 @@ public class LineageCacheEntry {
                notifyAll();
        }
 
+       public synchronized void setRDDValue(RDDObject rdd) {
+               _rddObject = rdd;
+               _status = isNullVal() ? LineageCacheStatus.EMPTY : 
LineageCacheStatus.TOPERSISTRDD;
+               //resume all threads waiting for val
+               notifyAll();
+       }
+
        public synchronized void setValue(byte[] serialBytes, long computetime) 
{
                _serialBytes = serialBytes;
                _computeTime = computetime;
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
index 9d642c8afe..1a9eece8c8 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
@@ -58,9 +58,9 @@ public class LineageCacheEviction
                        return;
 
                double exectime = ((double) entry._computeTime) / 1000000; // 
in milliseconds
-               if (!entry.isMatrixValue() && exectime >= 
LineageCacheConfig.MIN_SPILL_TIME_ESTIMATE)
+               if (!entry.isMatrixValue() && (exectime >= 
LineageCacheConfig.MIN_SPILL_TIME_ESTIMATE || entry._origItem != null))
                        // Pin the entries having scalar values and with higher 
computation time
-                       // to memory, to save those from eviction. Scalar 
values are
+                       // or function outputs to memory, to save them from eviction. Scalar values are
                        // not spilled to disk and are just deleted. Scalar 
entries associated 
                        // with high computation time might contain function 
outputs. Pinning them
                        // will increase chances of multilevel reuse.
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
index 6444df89d8..8dfbc44032 100644
--- 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageSparkCacheEviction.java
@@ -22,20 +22,28 @@ package org.apache.sysds.runtime.lineage;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.instructions.spark.data.LineageObject;
+import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCacheStatus;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Map;
 import java.util.TreeSet;
+import java.util.concurrent.Executors;
 
 public class LineageSparkCacheEviction
 {
        private static long SPARK_STORAGE_LIMIT = 0; //60% (upper limit of 
Spark unified memory)
        private static long _sparkStorageSize = 0; //current size
        private static TreeSet<LineageCacheEntry> weightedQueue = new 
TreeSet<>(LineageCacheConfig.LineageCacheComparator);
+       protected static final Map<LineageItem, Integer> RDDHitCountLocal = new 
HashMap<>();
 
        protected static void resetEviction() {
                _sparkStorageSize = 0;
                weightedQueue.clear();
+               RDDHitCountLocal.clear();
        }
 
        //--------------- CACHE MAINTENANCE & LOOKUP FUNCTIONS --------------//
@@ -155,4 +163,73 @@ public class LineageSparkCacheEviction
                        removeEntry(cache, e);
                }
        }
+
+       //---------------- LOCAL CLEANUP METHODS -----------------//
+
+       protected static void cleanupChildRDDs(LineageCacheEntry e) {
+               if (e.getCacheStatus() == LineageCacheStatus.PERSISTEDRDD) {
+                       // Persisted at Spark. Cleanup the child RDDs and 
broadcast vars
+                       for( LineageObject c : 
e.getRDDObject().getLineageChilds() ){
+                               c.decrementNumReferences();
+                               rCleanupChildRDDs(c);
+                       }
+                       // Also detach the child RDDs to be GCed
+                       e.getRDDObject().removeAllChild();
+               }
+               // TODO: Cleanup the child RDDs of the persisted RDDs
+               //  which are never reused after the second hit.
+       }
+
+       protected static void rCleanupChildRDDs(LineageObject lob) {
+               // Abort recursive cleanup if still consumers
+               if( lob.getNumReferences() > 0 )
+                       return;
+
+               // Abort if still reachable through live matrix object
+               if( lob.hasBackReference() )
+                       return;
+
+               // Abort if the RDD is yet to be persisted
+               if (lob instanceof RDDObject && lob.isInLineageCache()
+                       && 
SparkExecutionContext.isRDDCached(((RDDObject)lob).getRDD().id()))
+                       return;
+
+               // Cleanup current lineage object (from driver/executors)
+               SparkExecutionContext.cleanupSingleLineageObject(lob);
+
+               //recursively process lineage children
+               for (LineageObject c : lob.getLineageChilds()) {
+                       c.decrementNumReferences();
+                       rCleanupChildRDDs(c);
+               }
+       }
+
+       // RDDs that are marked for persistence, reused more than three times,
+       // but never actually persisted in the executors. Asynchronously move
+       // them to Spark by triggering a Spark job. The next reuse will then
+       // clean up the child RDDs and broadcast variables.
+       protected static void moveToSpark(LineageCacheEntry e) {
+               RDDHitCountLocal.merge(e._key, 1, Integer::sum);
+               int localHitCount = RDDHitCountLocal.get(e._key);
+               if (localHitCount > 3) {
+                       RDDHitCountLocal.remove(e._key);
+                       if (CommonThreadPool.triggerRemoteOPsPool == null)
+                               CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
+                       CommonThreadPool.triggerRemoteOPsPool.submit(new 
TriggerRemoteTask(e.getRDDObject().getRDD()));
+               }
+       }
+
+       private static class TriggerRemoteTask implements Runnable {
+               JavaPairRDD<?, ?> rdd;
+
+               public TriggerRemoteTask(JavaPairRDD<?,?> persistRDD) {
+                       rdd = persistRDD;
+               }
+
+               @Override
+               public void run() {
+                       // Trigger a Spark job
+                       long ret = rdd.count();
+               }
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
 
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
index d9200689fb..f2cf085838 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
@@ -29,6 +29,7 @@ import org.apache.sysds.hops.recompile.Recompiler;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.lineage.Lineage;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
 import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -41,7 +42,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
 
        protected static final String TEST_DIR = "functions/async/";
        protected static final String TEST_NAME = "LineageReuseSpark";
-       protected static final int TEST_VARIANTS = 4;
+       protected static final int TEST_VARIANTS = 6;
        protected static String TEST_CLASS_DIR = TEST_DIR + 
LineageReuseSparkTest.class.getSimpleName() + "/";
 
        @Override
@@ -70,7 +71,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
 
        @Test
        public void testL2svm() {
-               runTest(TEST_NAME+"3", ExecMode.SPARK, 3);
+               runTest(TEST_NAME+"3", ExecMode.HYBRID, 3);
        }
 
        @Test
@@ -79,6 +80,17 @@ public class LineageReuseSparkTest extends AutomatedTestBase 
{
                runTest(TEST_NAME+"4", ExecMode.HYBRID, 4);
        }
 
+       @Test
+       public void testEnsemble() {
+               runTest(TEST_NAME+"5", ExecMode.HYBRID, 5);
+       }
+
+       //FIXME: Collecting a persisted RDD still needs the broadcast vars. 
Debug.
+       /*@Test
+       public void testHyperband() {
+               runTest(TEST_NAME+"6", ExecMode.HYBRID, 6);
+       }*/
+
        public void runTest(String testname, ExecMode execMode, int testId) {
                boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                boolean old_sum_product = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -130,12 +142,17 @@ public class LineageReuseSparkTest extends 
AutomatedTestBase {
                        boolean matchVal = TestUtils.compareMatrices(R, 
R_reused, 1e-6, "Origin", "withPrefetch");
                        if (!matchVal)
                                System.out.println("Value w/o reuse "+R+" w/ 
reuse "+R_reused);
-                       if (testId == 1 || testId == 3) {
+                       if (testId == 1) {
                                Assert.assertTrue("Violated sp_tsmm reuse 
count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
                                Assert.assertTrue("Violated sp_mapmm reuse 
count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
                        }
+                       if (testId == 3)
+                               Assert.assertTrue("Violated sp_mapmm reuse 
count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
                        if (testId == 2)
                                Assert.assertTrue("Violated sp_rmm reuse count: 
" + numRmm_r + " < " + numRmm, numRmm_r < numRmm);
+                       if (testId == 4 || testId == 5) { // fn/SB reuse
+                               
Assert.assertTrue((LineageCacheStatistics.getMultiLevelFnHits() + 
LineageCacheStatistics.getMultiLevelSBHits()) > 1);
+                       }
                } finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
old_simplification;
                        OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = 
old_sum_product;
diff --git a/src/test/scripts/functions/async/LineageReuseSpark4.dml 
b/src/test/scripts/functions/async/LineageReuseSpark4.dml
index 90f270c3ba..18cff0cf34 100644
--- a/src/test/scripts/functions/async/LineageReuseSpark4.dml
+++ b/src/test/scripts/functions/async/LineageReuseSpark4.dml
@@ -19,16 +19,16 @@
 #
 #-------------------------------------------------------------
 
-SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y, Double lamda, 
Integer N) 
+SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y) 
 return (Matrix[double] A, Matrix[double] b)
 {
   # Reuse sp_tsmm and sp_mapmm if not future-based
-  A = (t(X) %*% X) + diag(matrix(lamda, rows=N, cols=1));
+  A = (t(X) %*% X); 
   while(FALSE){}
   b = t(X) %*% y;
 }
 
-no_lamda = 2;
+no_lamda = 5;
 
 stp = (0.1 - 0.0001)/no_lamda;
 lamda = 0.0001;
@@ -38,19 +38,35 @@ X = rand(rows=1500, cols=1500, seed=42);
 y = rand(rows=1500, cols=1, seed=43);
 N = ncol(X);
 R = matrix(0, rows=N, cols=no_lamda+2);
+i = 1;
 
-[A, b] = SimlinRegDS(X, y, lamda, N);
-beta = solve(A, b);
+while (lamda < lim)
+{
+  [A, b] = SimlinRegDS(X, y);
+  A_diag = A + diag(matrix(lamda, rows=N, cols=1));
+  beta = solve(A_diag, b);
+  R[,i] = beta;
+  lamda = lamda + stp;
+  i = i + 1;
+}
+
+/*[A, b] = SimlinRegDS(X, y);
+A_diag = A + diag(matrix(lamda, rows=N, cols=1));
+beta = solve(A_diag, b);
 R[,1] = beta;
+lamda = lamda + stp;
 
 # Reuse function call
-[A, b] = SimlinRegDS(X, y, lamda, N);
-beta = solve(A, b);
+[A, b] = SimlinRegDS(X, y);
+A_diag = A + diag(matrix(lamda, rows=N, cols=1));
+beta = solve(A_diag, b);
 R[,2] = beta;
+lamda = lamda + stp;
 
-[A, b] = SimlinRegDS(X, y, lamda, N);
-beta = solve(A, b);
-R[,3] = beta;
+[A, b] = SimlinRegDS(X, y);
+A_diag = A + diag(matrix(lamda, rows=N, cols=1));
+beta = solve(A_diag, b);
+R[,3] = beta;*/
 
 R = sum(R);
 write(R, $1, format="text");
diff --git a/src/test/scripts/functions/async/LineageReuseSpark5.dml 
b/src/test/scripts/functions/async/LineageReuseSpark5.dml
new file mode 100644
index 0000000000..6d6245c588
--- /dev/null
+++ b/src/test/scripts/functions/async/LineageReuseSpark5.dml
@@ -0,0 +1,56 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+M = 10000;
+N = 200;
+sp = 1.0; #1.0
+nweights = 10; #3000
+
+X = rand(rows=M, cols=N, sparsity=sp, seed=42);
+y = rand(rows=M, cols=1, min=0, max=2, seed=42);
+y = ceil(y);
+
+model_svm = l2svm(X=X, Y=y, intercept=TRUE, epsilon=1e-12,
+ reg=0.001, maxIterations=20, verbose=FALSE);
+model_mlr = multiLogReg(X=X, Y=y, icpt=2, tol=1e-6, reg=0.001, maxi=20, 
maxii=20, verbose=FALSE);
+
+# Assign random weights and grid search top-k models
+bestAcc = 0;
+weights = rand(rows=2, cols=nweights, min=0, max=1, seed=42);
+nclass = 2; # NOTE(review): nclass is not referenced below
+k = 2;
+for (wi in 1:nweights) {
+  weightedClassProb = matrix(0, M, 2);
+  for (i in 1:k) { # NOTE(review): i is unused; identical predictions are recomputed k times, presumably to exercise lineage reuse
+    [yRaw, yPred] = l2svmPredict(X=X, W=model_svm, verbose=FALSE);
+    probs_svm = yRaw / rowSums(yRaw);
+    [prob_mlr, Y_mlr, acc] = multiLogRegPredict(X=X, B=model_mlr, Y=y, 
verbose=FALSE);
+    weightedClassProb = weightedClassProb + as.scalar(weights[1,wi])*probs_svm 
+ as.scalar(weights[2,wi])*prob_mlr;
+    y_voted = rowIndexMax(weightedClassProb);
+    acc = sum(y_voted == y) / M * 100; # weighted-vote accuracy in percent; replaces the acc returned by multiLogRegPredict
+    if (acc > bestAcc) {
+      bestWeights = weights; # NOTE(review): stores the whole weights matrix, not column wi -- confirm intent
+      bestAcc = acc;
+    }
+  }
+}
+R = bestAcc;
+write(R, $1, format="text");
+
diff --git a/src/test/scripts/functions/async/LineageReuseSpark6.dml 
b/src/test/scripts/functions/async/LineageReuseSpark6.dml
new file mode 100644
index 0000000000..e1b253a8cf
--- /dev/null
+++ b/src/test/scripts/functions/async/LineageReuseSpark6.dml
@@ -0,0 +1,101 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+randRegSample = function(Matrix[Double] lamdas, Double ratio)
+return (Matrix[Double] samples) { # randomly keep about a `ratio` fraction of the rows of lamdas
+  temp = rand(rows=nrow(lamdas), cols=1, min=0, max=1, seed=42) < ratio; # 0/1 row mask; fixed seed => same mask for the same row count on every call
+  samples = removeEmpty(target=lamdas, margin="rows", select=temp); # keep rows whose mask entry is 1; NOTE(review): an all-zero mask is not guarded here
+}
+
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B)
+return (Double accuracy) {
+  # Despite the name, scores B with l2svmPredict and returns accuracy in percent (share of rows where yPred == y)
+  [yRaw, yPred] = l2svmPredict(X=X, W=B, verbose=FALSE);
+  accuracy = sum((yPred - y) == 0) / nrow(y) * 100;
+}
+
+M = 10000;
+N = 200;
+sp = 1.0; #1.0
+no_bracket = 2; #5
+
+X = rand(rows=M, cols=N, sparsity=sp, seed=42);
+y = rand(rows=M, cols=1, min=0, max=2, seed=42);
+y = ceil(y);
+
+no_lamda = 25; #starting combinations = 25 * 4 = 100 HPs
+stp = (0.1 - 0.0001)/no_lamda;
+HPlamdas = seq(0.0001, 0.1, stp);
+maxIter = 10; #starting iteration count = 100 * 10 = 1k
+
+for (r in 1:no_bracket) {
+  i = 1;
+  svmModels = matrix(0, rows=no_lamda, cols=ncol(X)+2); #first col is accuracy
+  mlrModels = matrix(0, rows=no_lamda, cols=ncol(X)+2); #first col is accuracy
+  # Optimize for regularization parameters
+  print("#lamda = "+no_lamda+", maxIterations = "+maxIter);
+  for (l in 1:no_lamda)
+  {
+    #print("lamda = "+as.scalar(HPlamdas[i,1])+", maxIterations = "+maxIter);
+    #Run L2svm with intercept true
+    beta = l2svm(X=X, Y=y, intercept=TRUE, epsilon=1e-12,
+      reg = as.scalar(HPlamdas[i,1]), maxIterations=maxIter, verbose=FALSE);
+    svmModels[i,1] = l2norm(X, y, beta); #1st column
+    svmModels[i,2:nrow(beta)+1] = t(beta);
+
+    #Run L2svm with intercept false
+    beta = l2svm(X=X, Y=y, intercept=FALSE, epsilon=1e-12,
+      reg = as.scalar(HPlamdas[i,1]), maxIterations=maxIter, verbose=FALSE);
+    svmModels[i,1] = l2norm(X, y, beta); #1st column; NOTE(review): same row i as the intercept=TRUE model above, which gets overwritten
+    svmModels[i,2:nrow(beta)+1] = t(beta);
+
+    #Run multilogreg with intercept true
+    beta = multiLogReg(X=X, Y=y, icpt=2, tol=1e-6, 
reg=as.scalar(HPlamdas[i,1]),
+      maxi=maxIter, maxii=20, verbose=FALSE);
+    [prob_mlr, Y_mlr, acc] = multiLogRegPredict(X=X, B=beta, Y=y, 
verbose=FALSE);
+    mlrModels[i,1] = acc; #1st column
+    mlrModels[i,2:nrow(beta)+1] = t(beta);
+
+    #Run multilogreg with intercept false
+    beta = multiLogReg(X=X, Y=y, icpt=1, tol=1e-6, 
reg=as.scalar(HPlamdas[i,1]),
+      maxi=maxIter, maxii=20, verbose=FALSE);
+    [prob_mlr, Y_mlr, acc] = multiLogRegPredict(X=X, B=beta, Y=y, 
verbose=FALSE);
+    mlrModels[i,1] = acc; #1st column; NOTE(review): same row i as the icpt=2 model above, which gets overwritten
+    mlrModels[i,2:nrow(beta)+1] = t(beta);
+
+    i = i + 1;
+  }
+  #Sort the models based on accuracy
+  svm_order = order(target=svmModels, by=1); # NOTE(review): without decreasing=TRUE this sorts ascending, so row 1 would hold the lowest accuracy -- confirm intent
+  bestAccSvm = svm_order[1,1];
+  print(toString(bestAccSvm));
+  mlr_order = order(target=mlrModels, by=1); # NOTE(review): same ascending-order concern as svm_order
+  bestAccMlr = mlr_order[1,1];
+  print(toString(bestAccMlr));
+
+  #double the iteration count and halve the HPs
+  maxIter = maxIter * 2;
+  HPlamdas = randRegSample(HPlamdas, 0.5);
+  #TODO: select the models with highest accuracies
+  no_lamda = nrow(HPlamdas);
+}
+R = sum(bestAccSvm) + sum(bestAccMlr);
+write(R, $1, format="text");
+

Reply via email to