This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new b37026b  [SYSTEMDS-2739] Fix Cost&Size eviction policy
b37026b is described below

commit b37026bac137f780bc9f3fb34887c7dac35a2a5c
Author: arnabp <[email protected]>
AuthorDate: Wed Feb 10 20:15:04 2021 +0100

    [SYSTEMDS-2739] Fix Cost&Size eviction policy
    
    This patch fixes a bug in the logic that adjusts scores by
    cache reference count. It also makes the estimation of saved
    and missed compute time more robust and accurate.
---
 .../apache/sysds/runtime/lineage/LineageCache.java | 30 +++++++++--
 .../sysds/runtime/lineage/LineageCacheEntry.java   | 30 +++++------
 .../runtime/lineage/LineageCacheEviction.java      | 10 ++--
 .../lineage/LineageEstimatorStatistics.java        |  2 +-
 .../sysds/runtime/lineage/LineageRewriteReuse.java | 61 ++++++++++++++++------
 5 files changed, 90 insertions(+), 43 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index d36a1a2..40962ec 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -153,6 +153,9 @@ public class LineageCache
                                        else
                                                ec.setScalarOutput(outName, 
e.getSOValue());
                                        reuse = true;
+
+                                       if (DMLScript.STATISTICS) //increment 
saved time
+                                               
LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
                                }
                                if (DMLScript.STATISTICS)
                                        
LineageCacheStatistics.incrementInstHits();
@@ -172,6 +175,7 @@ public class LineageCache
                        return false;
                
                boolean reuse = (outParams.size() != 0);
+               long savedComputeTime = 0;
                HashMap<String, Data> funcOutputs = new HashMap<>();
                HashMap<String, LineageItem> funcLIs = new HashMap<>();
                for (int i=0; i<numOutputs; i++) {
@@ -211,6 +215,8 @@ public class LineageCache
                                funcOutputs.put(boundVarName, boundValue);
                                LineageItem orig = e._origItem;
                                funcLIs.put(boundVarName, orig);
+                               //all the entries have the same computeTime
+                               savedComputeTime = e._computeTime;
                        }
                        else {
                                // if one output cannot be reused, we need to 
execute the function
@@ -231,6 +237,9 @@ public class LineageCache
                        });
                        //map original lineage items return to the calling site
                        funcLIs.forEach((var, li) -> ec.getLineage().set(var, 
li));
+
+                       if (DMLScript.STATISTICS) //increment saved time
+                               
LineageCacheStatistics.incrementSavedComputeTime(savedComputeTime);
                }
                
                return reuse;
@@ -246,6 +255,7 @@ public class LineageCache
                boolean reuse = false;
                List<Long> outIds = udf.getOutputIds();
                HashMap<String, Data> udfOutputs = new HashMap<>();
+               long savedComputeTime = 0;
 
                //TODO: support multi-return UDFs
                if (udf.getLineageItem(ec) == null)
@@ -278,6 +288,7 @@ public class LineageCache
                                outValue = e.getSOValue();
                        }
                        udfOutputs.put(outName, outValue);
+                       savedComputeTime = e._computeTime;
                        reuse = true;
                }
                else
@@ -298,9 +309,11 @@ public class LineageCache
                                res = LineageItemUtils.setUDFResponse(udf, 
(MatrixObject) val);
                        }
 
-                       if (DMLScript.STATISTICS)
+                       if (DMLScript.STATISTICS) {
                                //TODO: dedicated stats for federated reuse
                                LineageCacheStatistics.incrementInstHits();
+                               
LineageCacheStatistics.incrementSavedComputeTime(savedComputeTime);
+                       }
                        
                        return res;
                }
@@ -323,6 +336,14 @@ public class LineageCache
                }
                return e.getMBValue();
        }
+
+       public static LineageCacheEntry getEntry(LineageItem key) {
+               LineageCacheEntry e = null;
+               synchronized( _cache ) {
+                       e = getIntern(key);
+               }
+               return e;
+       }
        
        //NOTE: safe to pin the object in memory as coming from CPInstruction
        //TODO why do we need both of these public put methods
@@ -545,11 +566,10 @@ public class LineageCache
                // This method is called only when entry is present either in 
cache or in local FS.
                LineageCacheEntry e = _cache.get(key);
                if (e != null && e.getCacheStatus() != 
LineageCacheStatus.SPILLED) {
-                       if (DMLScript.STATISTICS) {
-                               // Increment hit count and saved computation 
time.
+                       if (DMLScript.STATISTICS)
+                               // Increment hit count.
                                LineageCacheStatistics.incrementMemHits();
-                               
LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
-                       }
+
                        // Maintain order for eviction
                        LineageCacheEviction.getEntry(e);
                        return e;
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index a82c8a5..1562d60 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -141,27 +141,23 @@ public class LineageCacheEntry {
        }
        
        protected synchronized void computeScore(Map<LineageItem, Integer> 
removeList) {
+               // Set timestamp and compute initial score
                setTimestamp();
-               if (removeList.containsKey(_key)) {
-                       //FIXME: increase computetime instead of score (that 
now leads to overflow).
-                       // updating computingtime seamlessly takes care of 
spilling 
-                       //_computeTime = _computeTime * (1 + 
removeList.get(_key));
-                       score = score * (1 + removeList.get(_key));
+
+               // Update score to emulate computeTime scaling by #misses
+               if (removeList.containsKey(_key) && 
LineageCacheConfig.isCostNsize()) {
+                       //score = score * (1 + removeList.get(_key));
+                       double w1 = LineageCacheConfig.WEIGHTS[0];
+                       int missCount = 1 + removeList.get(_key);
+                       score = score + (w1*(((double)_computeTime)/getSize()) 
* missCount);
                }
-               if (_computeTime < 0)
-                       System.out.println("after recache: "+_computeTime+" 
miss count: "+removeList.get(_key));
        }
        
-       protected synchronized void updateComputeTime() {
-               if ((Long.MAX_VALUE - _computeTime) < _computeTime) {
-                       System.out.println("Overflow for: "+_key.getOpcode());
-               }
-               //FIXME: increase computetime instead of score (that now leads 
to overflow).
-               // updating computingtime seamlessly takes care of spilling 
-               //_computeTime = _computeTime * (1 + removeList.get(_key));
-               //_computeTime += _computeTime;
-               //recomputeScore();
-               score *= 2;
+       protected synchronized void updateScore() {
+               // Update score to emulate computeTime scaling by cache hit
+               //score *= 2;
+               double w1 = LineageCacheConfig.WEIGHTS[0];
+               score = score + w1*(((double)_computeTime)/getSize());
        }
        
        protected synchronized long getTimestamp() {
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
index e8c4d8d..c64d49b 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
@@ -70,7 +70,8 @@ public class LineageCacheEviction
                        // Don't add the memory pinned entries in weighted 
queue. 
                        // The eviction queue should contain only entries that 
can
                        // be removed or spilled to disk.
-                       //entry.setTimestamp();
+
+                       // Set timestamp, score, and scale score by #misses
                        entry.computeScore(_removelist); 
                        // Adjust score according to cache miss counts.
                        weightedQueue.add(entry);
@@ -85,11 +86,11 @@ public class LineageCacheEviction
                                weightedQueue.add(entry);
                        }
                }
-               // Increase computation time of the sought entry.
+               // Scale score of the sought entry after every cache hit
                // FIXME: avoid when called from partial reuse methods
                if (LineageCacheConfig.isCostNsize()) {
                        if (weightedQueue.remove(entry)) {
-                               entry.updateComputeTime();
+                               entry.updateScore();
                                weightedQueue.add(entry);
                        }
                }
@@ -99,7 +100,7 @@ public class LineageCacheEviction
                if (cache.remove(e._key) != null)
                        _cachesize -= e.getSize();
 
-               // Increase priority if same entry is removed multiple times
+               // Maintain miss count to increase the score if the item enters 
the cache again
                if (_removelist.containsKey(e._key))
                        _removelist.replace(e._key, _removelist.get(e._key)+1);
                else
@@ -224,7 +225,6 @@ public class LineageCacheEviction
                        // Estimate time to write to FS + read from FS.
                        double spilltime = getDiskSpillEstimate(e) * 1000; // 
in milliseconds
                        double exectime = ((double) e._computeTime) / 1000000; 
// in milliseconds
-                       //FIXME: this comuteTime is not adjusted according to 
hit/miss counts
 
                        if (LineageCache.DEBUG) {
                                System.out.print("LI = " + e._key.getOpcode());
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageEstimatorStatistics.java
 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageEstimatorStatistics.java
index 1f6bb11..cb80f0b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageEstimatorStatistics.java
+++ 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageEstimatorStatistics.java
@@ -47,7 +47,7 @@ public class LineageEstimatorStatistics {
        }
        
        public static String displaySize() {
-               //size of all cached reusable intermediates/size of reused 
intermediates//cache size
+               //size of all cached reusable intermediates/size of reused 
intermediates/cache size
                StringBuilder sb = new StringBuilder();
                sb.append(String.format("%.3f", 
((double)LineageEstimator._totReusableSize)/(1024*1024))); //in MB
                sb.append("/");
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
index 60b2eea..e093eb3 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
@@ -72,6 +72,7 @@ public class LineageRewriteReuse
        private static BasicProgramBlock _lrPB = null;
        private static ExecutionContext _lrEC = null;
        private static boolean _disableReuse = true;
+       private static long _computeTime = 0;
        private static final Log LOG = 
LogFactory.getLog(LineageRewriteReuse.class.getName());
        
        public static boolean executeRewrites (Instruction curr, 
ExecutionContext ec)
@@ -120,7 +121,9 @@ public class LineageRewriteReuse
                
ec.setVariable(((ComputationCPInstruction)curr).output.getName(), 
lrwec.getVariable(LR_VAR));
 
                //put the result into the cache
-               LineageCache.putMatrix(curr, ec, t1-t0);
+               //Projected CT(Rewritten entry) = CT(last entry) + CT(rewrite), 
where CT = ComputeTime
+               long totCT = _computeTime + (t1-t0);
+               LineageCache.putMatrix(curr, ec, totCT);
                DMLScript.EXPLAIN = et; //TODO can't change this here
                
                //cleanup execution context
@@ -826,8 +829,10 @@ public class LineageRewriteReuse
                                // create tsmm lineage on top of the input of 
last append
                                LineageItem input1 = source.getInputs()[0];
                                LineageItem tmp = new 
LineageItem(curr.getOpcode(), new LineageItem[] {input1});
-                               if (LineageCache.probe(tmp)) 
+                               if (LineageCache.probe(tmp)) { 
                                        inCache.put("lastMatrix", 
LineageCache.getMatrix(tmp));
+                                       _computeTime = 
LineageCache.getEntry(tmp)._computeTime;
+                               }
                                // look for the old matrix in cache
                                if( LineageCache.probe(input1) )
                                        inCache.put("X", 
LineageCache.getMatrix(input1));
@@ -863,8 +868,10 @@ public class LineageRewriteReuse
                                LineageItem tmp = new 
LineageItem(curr.getOpcode(), new LineageItem[] {input1});
                                if( LineageCache.probe(input1) )
                                        inCache.put("X", 
LineageCache.getMatrix(input1));
-                               if (LineageCache.probe(tmp)) 
+                               if (LineageCache.probe(tmp)) { 
                                        inCache.put("lastMatrix", 
LineageCache.getMatrix(tmp));
+                                       _computeTime += 
LineageCache.getEntry(tmp)._computeTime;
+                               }
                        }
                }
                // return true only if the last tsmm result is found
@@ -884,8 +891,10 @@ public class LineageRewriteReuse
                                // create tsmm lineage on top of the input of 
last append
                                LineageItem input1 = source.getInputs()[0];
                                LineageItem tmp = new 
LineageItem(curr.getOpcode(), new LineageItem[] {input1});
-                               if (LineageCache.probe(tmp)) 
+                               if (LineageCache.probe(tmp)) { 
                                        inCache.put("lastMatrix", 
LineageCache.getMatrix(tmp));
+                                       _computeTime = 
LineageCache.getEntry(tmp)._computeTime;
+                               }
                                // look for the appended column in cache
                                if (source.getInputs().length>1 && 
LineageCache.probe(source.getInputs()[1])) 
                                        inCache.put("deltaX", 
LineageCache.getMatrix(source.getInputs()[1]));
@@ -912,8 +921,10 @@ public class LineageRewriteReuse
                                        LineageItem L2appin1 = 
input.getInputs()[0]; 
                                        LineageItem tmp = new 
LineageItem("cbind", new LineageItem[] {L2appin1, source.getInputs()[1]});
                                        LineageItem toProbe = new 
LineageItem(curr.getOpcode(), new LineageItem[] {tmp});
-                                       if (LineageCache.probe(toProbe)) 
+                                       if (LineageCache.probe(toProbe)) { 
                                                inCache.put("lastMatrix", 
LineageCache.getMatrix(toProbe));
+                                               _computeTime = 
LineageCache.getEntry(toProbe)._computeTime;
+                                       }
                                        // look for the appended column in cache
                                        if 
(LineageCache.probe(input.getInputs()[1])) 
                                                inCache.put("deltaX", 
LineageCache.getMatrix(input.getInputs()[1]));
@@ -951,8 +962,10 @@ public class LineageRewriteReuse
                                        LineageItem old_cbind = new 
LineageItem("cbind", new LineageItem[] {L2appin1, old_RI});
                                        LineageItem tmp = new 
LineageItem("cbind", new LineageItem[] {old_cbind, source.getInputs()[1]});
                                        LineageItem toProbe = new 
LineageItem(curr.getOpcode(), new LineageItem[] {tmp});
-                                       if (LineageCache.probe(toProbe)) 
+                                       if (LineageCache.probe(toProbe)) { 
                                                inCache.put("lastMatrix", 
LineageCache.getMatrix(toProbe));
+                                               _computeTime = 
LineageCache.getEntry(toProbe)._computeTime;
+                                       }
                                }
                        }
                }
@@ -974,8 +987,10 @@ public class LineageRewriteReuse
                                LineageItem leftSource = left.getInputs()[0]; 
//left inpur of rbind = X
                                // create ba+* lineage on top of the input of 
last append
                                LineageItem tmp = new 
LineageItem(curr.getOpcode(), new LineageItem[] {leftSource, right});
-                               if (LineageCache.probe(tmp))
+                               if (LineageCache.probe(tmp)) {
                                        inCache.put("lastMatrix", 
LineageCache.getMatrix(tmp));
+                                       _computeTime = 
LineageCache.getEntry(tmp)._computeTime;
+                               }
                                // look for the appended column in cache
                                if (LineageCache.probe(left.getInputs()[1])) 
                                        inCache.put("deltaX", 
LineageCache.getMatrix(left.getInputs()[1]));
@@ -999,8 +1014,10 @@ public class LineageRewriteReuse
                                LineageItem rightSource = right.getInputs()[0]; 
//left inpur of rbind = X
                                // create ba+* lineage on top of the input of 
last append
                                LineageItem tmp = new 
LineageItem(curr.getOpcode(), new LineageItem[] {left, rightSource});
-                               if (LineageCache.probe(tmp))
+                               if (LineageCache.probe(tmp)) {
                                        inCache.put("lastMatrix", 
LineageCache.getMatrix(tmp));
+                                       _computeTime = 
LineageCache.getEntry(tmp)._computeTime;
+                               }
                                // look for the appended column in cache
                                if (LineageCache.probe(right.getInputs()[1])) 
                                        inCache.put("deltaY", 
LineageCache.getMatrix(right.getInputs()[1]));
@@ -1030,8 +1047,10 @@ public class LineageRewriteReuse
                                        return false;
                                // create ba+* lineage on top of the input of 
last append
                                LineageItem tmp = new 
LineageItem(curr.getOpcode(), new LineageItem[] {left, rightSource1});
-                               if (LineageCache.probe(tmp))
+                               if (LineageCache.probe(tmp)) {
                                        inCache.put("lastMatrix", 
LineageCache.getMatrix(tmp));
+                                       _computeTime = 
LineageCache.getEntry(tmp)._computeTime;
+                               }
                        }
                }
                return inCache.containsKey("lastMatrix") ? true : false;
@@ -1052,8 +1071,10 @@ public class LineageRewriteReuse
                                LineageItem rightSource = right.getInputs()[0]; 
//right inpur of rbind = Y 
                                // create * lineage on top of the input of last 
append
                                LineageItem tmp = new 
LineageItem(curr.getOpcode(), new LineageItem[] {leftSource, rightSource});
-                               if (LineageCache.probe(tmp))
+                               if (LineageCache.probe(tmp)) {
                                        inCache.put("lastMatrix", 
LineageCache.getMatrix(tmp));
+                                       _computeTime = 
LineageCache.getEntry(tmp)._computeTime;
+                               }
                                // look for the appended rows in cache
                                if (LineageCache.probe(left.getInputs()[1]))
                                        inCache.put("deltaX", 
LineageCache.getMatrix(left.getInputs()[1]));
@@ -1079,8 +1100,10 @@ public class LineageRewriteReuse
                                LineageItem rightSource = right.getInputs()[0]; 
//right inpur of cbind = Y 
                                // create * lineage on top of the input of last 
append
                                LineageItem tmp = new 
LineageItem(curr.getOpcode(), new LineageItem[] {leftSource, rightSource});
-                               if (LineageCache.probe(tmp))
+                               if (LineageCache.probe(tmp)) {
                                        inCache.put("lastMatrix", 
LineageCache.getMatrix(tmp));
+                                       _computeTime = 
LineageCache.getEntry(tmp)._computeTime;
+                               }
                                // look for the appended columns in cache
                                if (LineageCache.probe(left.getInputs()[1]))
                                        inCache.put("deltaX", 
LineageCache.getMatrix(left.getInputs()[1]));
@@ -1110,8 +1133,10 @@ public class LineageRewriteReuse
                                LineageItem input1 = target.getInputs()[0];
                                LineageItem tmp = new 
LineageItem(curr.getOpcode(), 
                                                new LineageItem[] {input1, 
groups, weights, fn, ngroups});
-                               if (LineageCache.probe(tmp)) 
+                               if (LineageCache.probe(tmp)) {
                                        inCache.put("lastMatrix", 
LineageCache.getMatrix(tmp));
+                                       _computeTime = 
LineageCache.getEntry(tmp)._computeTime;
+                               }
                                // look for the appended column in cache
                                if (LineageCache.probe(target.getInputs()[1])) 
                                        inCache.put("deltaX", 
LineageCache.getMatrix(target.getInputs()[1]));
@@ -1137,8 +1162,10 @@ public class LineageRewriteReuse
                        LineageItem right = item.getInputs()[1];
                        if (right.getOpcode().equalsIgnoreCase("rightIndex")) {
                                LineageItem indexSource = right.getInputs()[0];
-                               if (LineageCache.probe(indexSource) && 
indexSource.getOpcode().equalsIgnoreCase("ba+*"))
+                               if (LineageCache.probe(indexSource) && 
indexSource.getOpcode().equalsIgnoreCase("ba+*")) {
                                        inCache.put("indexSource", 
LineageCache.getMatrix(indexSource));
+                                       _computeTime = 
LineageCache.getEntry(indexSource)._computeTime;
+                               }
                                LineageItem tmp = new 
LineageItem(item.getOpcode(), new LineageItem[] {left, indexSource});
                                if (LineageCache.probe(tmp))
                                        inCache.put("BigMatMult", 
LineageCache.getMatrix(tmp));
@@ -1160,8 +1187,10 @@ public class LineageRewriteReuse
                                LineageItem src21 = src1.getInputs()[0];
                                LineageItem src22 = src1.getInputs()[1]; //ones
                                if (src21.getOpcode().equalsIgnoreCase("ba+*")) 
{
-                                       if (LineageCache.probe(src21))
+                                       if (LineageCache.probe(src21)) {
                                                inCache.put("projected", 
LineageCache.getMatrix(src21));
+                                               _computeTime = 
LineageCache.getEntry(src21)._computeTime;
+                                       }
                                
                                        LineageItem src31 = 
src21.getInputs()[1];
                                        LineageItem src32 = 
src21.getInputs()[0];
@@ -1174,8 +1203,10 @@ public class LineageRewriteReuse
                                                LineageItem old_ba = new 
LineageItem("ba+*", new LineageItem[] {src32, old_RI});
                                                LineageItem old_cbind = new 
LineageItem("cbind", new LineageItem[] {old_ba, src22});
                                                LineageItem old_tsmm = new 
LineageItem("tsmm", new LineageItem[] {old_cbind});
-                                               if 
(LineageCache.probe(old_tsmm))
+                                               if 
(LineageCache.probe(old_tsmm)) {
                                                        
inCache.put("lastMatrix", LineageCache.getMatrix(old_tsmm));
+                                                       _computeTime += 
LineageCache.getEntry(old_tsmm)._computeTime;
+                                               }
                                        }
                                }
                        }

Reply via email to