This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new 822b492 [SYSTEMDS-335] Updated weighted eviction scheme for lineage
cache
822b492 is described below
commit 822b4922b938ece3a23204823f818545d471bae4
Author: arnabp <[email protected]>
AuthorDate: Tue May 26 21:01:28 2020 +0200
[SYSTEMDS-335] Updated weighted eviction scheme for lineage cache
This patch updates the weighted scheme by adding a more elaborate scoring
function. The function has two components: the ratio of compute time to
in-memory size, and a last-used timestamp. Each component is associated
with a weight, and the weights tune the eviction policy (e.g., weights 0 and 1
for time/size and timestamp, respectively, yield an LRU scheme). This
patch also replaces the earlier PriorityQueue with a TreeSet.
New eviction test, refactor LineageCacheConfig, eviction logic tuning.
This commit contains:
1) A few updates to the eviction logic. Thanks to Matthias for catching an
unneeded enqueue/dequeue.
2) Refactoring of the LineageCacheConfig class.
3) A new test that compares the order of evicted items under the
specified policies.
Closes #915.
---
docs/Tasks.txt | 2 +-
.../sysds/runtime/lineage/LineageCacheConfig.java | 154 +++++++++++++--------
.../sysds/runtime/lineage/LineageCacheEntry.java | 11 +-
.../runtime/lineage/LineageCacheEviction.java | 103 +++++---------
.../runtime/lineage/LineageCacheStatistics.java | 5 +-
.../test/functions/dnn/Conv2DBackwardDataTest.java | 3 +-
.../test/functions/dnn/Conv2DBackwardTest.java | 2 +-
.../sysds/test/functions/dnn/Conv2DTest.java | 2 +-
.../sysds/test/functions/dnn/PoolBackwardTest.java | 2 +-
.../apache/sysds/test/functions/dnn/PoolTest.java | 2 +-
.../sysds/test/functions/dnn/ReluBackwardTest.java | 44 ++----
.../test/functions/lineage/CacheEvictionTest.java | 141 +++++++++++++++++++
.../scripts/functions/lineage/CacheEviction1.dml | 55 ++++++++
.../scripts/functions/lineage/LineageReuseAlg3.dml | 2 +-
14 files changed, 357 insertions(+), 171 deletions(-)
diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 081a44b..91c966d 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -270,7 +270,7 @@ SYSTEMDS-330 Lineage Tracing, Reuse and Integration
* 332 Parfor integration with multi-level reuse OK
* 333 Improve cache eviction with actual compute time OK
* 334 Cache scalars only with atleast one matrix inputs
- * 335 Weighted eviction policy (function of size & computetime) OK
+ * 335 Weighted eviction policy (function(size,computetime,LRU time)) OK
* 336 Better use of cache status to handle multithreading
* 337 Adjust disk I/O speed by recording actual time taken OK
* 338 Extended lineage tracing (rmEmpty, lists), partial rewrites OK
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 888d27d..2a3c426 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -26,12 +26,14 @@ import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListIndexingCPInstruction;
-import java.util.ArrayList;
+import java.util.Comparator;
+
+public class LineageCacheConfig
+{
+ //-------------CACHING LOGIC RELATED CONFIGURATIONS--------------//
-public class LineageCacheConfig {
-
private static final String[] REUSE_OPCODES = new String[] {
- "tsmm", "ba+*", "*", "/", "+", "nrow", "ncol",
+ "tsmm", "ba+*", "*", "/", "+", "nrow", "ncol", "round", "exp",
"log",
"rightIndex", "leftIndex", "groupedagg", "r'", "solve", "spoof"
};
@@ -55,63 +57,81 @@ public class LineageCacheConfig {
|| DMLScript.LINEAGE_REUSE == NONE;
}
}
+
+ private static ReuseCacheType _cacheType = null;
+ private static CachedItemHead _itemH = null;
+ private static CachedItemTail _itemT = null;
+ private static boolean _compilerAssistedRW = false;
+
+ //-------------DISK SPILLING RELATED CONFIGURATIONS--------------//
+
+ private static boolean _allowSpill = false;
+ // Minimum reliable spilling estimate in milliseconds.
+ public static final double MIN_SPILL_TIME_ESTIMATE = 100;
+ // Minimum reliable data size for spilling estimate in MB.
+ public static final double MIN_SPILL_DATA = 20;
+ // Default I/O in MB per second for binary blocks
+ public static double FSREAD_DENSE = 200;
+ public static double FSREAD_SPARSE = 100;
+ public static double FSWRITE_DENSE = 150;
+ public static double FSWRITE_SPARSE = 75;
- public enum CachedItemHead {
+ private enum CachedItemHead {
TSMM,
ALL
}
- public enum CachedItemTail {
+ private enum CachedItemTail {
CBIND,
RBIND,
INDEX,
ALL
}
-
- public enum LineageCacheStatus {
- EMPTY, //Placeholder with no data. Cannot be evicted.
- CACHED, //General cached data. Can be evicted.
- EVICTED, //Data is in disk. Empty value. Cannot be
evicted.
- RELOADED, //Reloaded from disk. Can be evicted.
- PINNED; //Pinned to memory. Cannot be evicted.
+
+ //-------------EVICTION RELATED CONFIGURATIONS--------------//
+
+ private static LineageCachePolicy _cachepolicy = null;
+ // Weights for scoring components (computeTime/size, LRU timestamp)
+ private static double[] WEIGHTS = {0, 1};
+
+ protected enum LineageCacheStatus {
+ EMPTY, //Placeholder with no data. Cannot be evicted.
+ CACHED, //General cached data. Can be evicted.
+ EVICTED, //Data is in disk. Empty value. Cannot be evicted.
+ RELOADED, //Reloaded from disk. Can be evicted.
+ PINNED; //Pinned to memory. Cannot be evicted.
public boolean canEvict() {
return this == CACHED || this == RELOADED;
}
- }
-
- public enum LineageCachePolicy {
- LRU,
- WEIGHTED;
- public boolean isLRUcache() {
- return this == LRU;
- }
- }
+ }
- public ArrayList<String> _MMult = new ArrayList<>();
- public static boolean _allowSpill = true;
- // Minimum reliable spilling estimate in milliseconds.
- public static final double MIN_SPILL_TIME_ESTIMATE = 100;
- // Minimum reliable data size for spilling estimate in MB.
- public static final double MIN_SPILL_DATA = 20;
-
- // Default I/O in MB per second for binary blocks
- public static double FSREAD_DENSE = 200;
- public static double FSREAD_SPARSE = 100;
- public static double FSWRITE_DENSE = 150;
- public static double FSWRITE_SPARSE = 75;
+ public enum LineageCachePolicy {
+ LRU,
+ WEIGHTED,
+ HYBRID;
+ }
+
+	protected static Comparator<LineageCacheEntry> LineageCacheComparator = (e1, e2) -> {
+		// Gather the weights for scoring components (computeTime/size, LRU timestamp)
+		double w1 = LineageCacheConfig.WEIGHTS[0];
+		double w2 = LineageCacheConfig.WEIGHTS[1];
+		// Generate scores. Each entry is scored only from its own compute time, size,
+		// and timestamp, using identical expressions so the ordering is antisymmetric.
+		double score1 = w1*(((double)e1._computeTime)/e1.getSize()) + w2*e1.getTimestamp();
+		double score2 = w1*(((double)e2._computeTime)/e2.getSize()) + w2*e2.getTimestamp();
+		// Generate order. If scores are same, order by LineageItem ID. Double.compare
+		// keeps a consistent total order even for NaN scores (e.g., a size of 0).
+		int cmp = Double.compare(score1, score2);
+		return cmp != 0 ? cmp : Long.compare(e1._key.getId(), e2._key.getId());
+	};
- private static ReuseCacheType _cacheType = null;
- private static CachedItemHead _itemH = null;
- private static CachedItemTail _itemT = null;
- private static LineageCachePolicy _cachepolicy = null;
- private static boolean _compilerAssistedRW = true;
+ //----------------------------------------------------------------//
static {
//setup static configuration parameters
- setSpill(true); //enable/disable disk spilling.
- setCachePolicy(LineageCachePolicy.WEIGHTED);
+ setSpill(true);
+ //setCachePolicy(LineageCachePolicy.WEIGHTED);
+ setCompAssRW(true);
}
-
+
+
public static boolean isReusable (Instruction inst, ExecutionContext
ec) {
boolean insttype = inst instanceof ComputationCPInstruction
&& !(inst instanceof ListIndexingCPInstruction);
@@ -158,23 +178,6 @@ public class LineageCacheConfig {
DMLScript.LINEAGE = true;
DMLScript.LINEAGE_REUSE = rop;
}
-
- public static void setSpill(boolean toSpill) {
- _allowSpill = toSpill;
- }
-
- public static void setCachePolicy(LineageCachePolicy policy) {
- _cachepolicy = policy;
- }
-
- public static boolean isSetSpill() {
- return _allowSpill;
- }
-
- public static LineageCachePolicy getCachePolicy() {
- return _cachepolicy;
- }
-
public static ReuseCacheType getCacheType() {
return _cacheType;
}
@@ -196,4 +199,37 @@ public class LineageCacheConfig {
return _compilerAssistedRW;
}
+ public static void setCachePolicy(LineageCachePolicy policy) {
+ switch(policy) {
+ case LRU:
+ WEIGHTS[0] = 0; WEIGHTS[1] = 1;
+ break;
+ case WEIGHTED:
+ WEIGHTS[0] = 1; WEIGHTS[1] = 0;
+ break;
+ case HYBRID:
+ WEIGHTS[0] = 1; WEIGHTS[1] = 1;
+ break;
+ }
+ _cachepolicy = policy;
+ }
+
+ public static LineageCachePolicy getCachePolicy() {
+ return _cachepolicy;
+ }
+
+	public static boolean isLRU() {
+		// True if the LRU (timestamp) score component is active, i.e., has a non-zero weight (LRU, HYBRID).
+		return (WEIGHTS[1] > 0);
+	}
+
+ public static void setSpill(boolean toSpill) {
+ _allowSpill = toSpill;
+ // NOTE: _allowSpill only enables/disables disk spilling, but
has
+ // no control over eviction order of cached items.
+ }
+
+ public static boolean isSetSpill() {
+ return _allowSpill;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index 9421208..485cac6 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -31,9 +31,8 @@ public class LineageCacheEntry {
protected MatrixBlock _MBval;
protected ScalarObject _SOval;
protected long _computeTime;
+ protected long _timestamp = 0;
protected LineageCacheStatus _status;
- protected LineageCacheEntry _prev;
- protected LineageCacheEntry _next;
protected LineageItem _origItem;
public LineageCacheEntry(LineageItem key, DataType dt, MatrixBlock
Mval, ScalarObject Sval, long computetime) {
@@ -109,4 +108,12 @@ public class LineageCacheEntry {
//resume all threads waiting for val
notifyAll();
}
+
+ protected synchronized void setTimestamp() {
+ _timestamp = System.currentTimeMillis();
+ }
+
+ protected synchronized long getTimestamp() {
+ return _timestamp;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
index f3c2c0e..127e152 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEviction.java
@@ -20,11 +20,10 @@
package org.apache.sysds.runtime.lineage;
import java.io.IOException;
-import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
-import java.util.PriorityQueue;
+import java.util.TreeSet;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
@@ -36,25 +35,14 @@ import org.apache.sysds.runtime.util.LocalFileUtils;
public class LineageCacheEviction
{
- private static LineageCacheEntry _head = null;
- private static LineageCacheEntry _end = null;
private static long _cachesize = 0;
private static long CACHE_LIMIT; //limit in bytes
protected static final HashSet<LineageItem> _removelist = new
HashSet<>();
private static final Map<LineageItem, SpilledItem> _spillList = new
HashMap<>();
private static String _outdir = null;
-
- private static Comparator<LineageCacheEntry> execTime2SizeComparator =
(e1, e2) -> {
- double t2s1 = ((double)e1._computeTime)/e1.getSize();
- double t2s2 = ((double)e2._computeTime)/e2.getSize();
- return t2s1 == t2s2 ? 0 : t2s1 < t2s2 ? -1 : 1;
- };
-
- private static PriorityQueue<LineageCacheEntry> weightedQueue = new
PriorityQueue<>(execTime2SizeComparator);
+ private static TreeSet<LineageCacheEntry> weightedQueue = new
TreeSet<>(LineageCacheConfig.LineageCacheComparator);
protected static void resetEviction() {
- _head = null;
- _end = null;
// reset cache size, otherwise the cache clear leads to
unusable
// space which means evictions could run into endless loops
_cachesize = 0;
@@ -65,11 +53,11 @@ public class LineageCacheEviction
_removelist.clear();
}
- //--------------- CACHE MAINTENANCE & LOOKUP FUNCTIONS ---------//
+ //--------------- CACHE MAINTENANCE & LOOKUP FUNCTIONS --------------//
protected static void addEntry(LineageCacheEntry entry) {
if (entry.isNullVal())
- // Placeholders shouldn't be evicted.
+ // Placeholders shouldn't participate in eviction
cycles.
return;
double exectime = ((double) entry._computeTime) / 1000000; //
in milliseconds
@@ -81,34 +69,30 @@ public class LineageCacheEviction
// will increase chances of multilevel reuse.
entry.setCacheStatus(LineageCacheStatus.PINNED);
- if (LineageCacheConfig.getCachePolicy().isLRUcache()) //LRU
- // Maintain linked list.
- setHead(entry);
- else {
- if (entry.isMatrixValue() || exectime <
LineageCacheConfig.MIN_SPILL_TIME_ESTIMATE)
- // Don't add the memory pinned entries in
weighted queue.
- // The priorityQueue should contain only
entries that can
- // be removed or spilled to disk.
- weightedQueue.add(entry);
+ if (entry.isMatrixValue() || exectime <
LineageCacheConfig.MIN_SPILL_TIME_ESTIMATE) {
+ // Don't add the memory pinned entries in weighted
queue.
+ // The eviction queue should contain only entries that
can
+ // be removed or spilled to disk.
+ entry.setTimestamp();
+ weightedQueue.add(entry);
}
}
protected static void getEntry(LineageCacheEntry entry) {
- if (LineageCacheConfig.getCachePolicy().isLRUcache()) { //LRU
- // maintain linked list.
- delete(entry);
- setHead(entry);
+ // Reset the timestamp to maintain the LRU component of the
scoring function
+ if (!LineageCacheConfig.isLRU())
+ return;
+
+ if (weightedQueue.remove(entry)) {
+ entry.setTimestamp();
+ weightedQueue.add(entry);
}
- // No maintenance is required for weighted scheme
}
protected static void removeEntry(Map<LineageItem, LineageCacheEntry>
cache, LineageItem key) {
if (!cache.containsKey(key))
return;
- if (LineageCacheConfig.getCachePolicy().isLRUcache()) //LRU
- delete(cache.get(key));
- else
- weightedQueue.remove(cache.get(key));
+ weightedQueue.remove(cache.get(key));
cache.remove(key);
}
@@ -116,41 +100,20 @@ public class LineageCacheEviction
if (DMLScript.STATISTICS)
_removelist.add(e._key);
- if (LineageCacheConfig.getCachePolicy().isLRUcache()) //LRU
- delete(e);
_cachesize -= e.getSize();
+ // NOTE: The caller of this method maintains the cache and the
eviction queue.
+
if (DMLScript.STATISTICS)
LineageCacheStatistics.incrementMemDeletes();
}
- private static void delete(LineageCacheEntry entry) {
- if (entry._prev != null)
- entry._prev._next = entry._next;
- else
- _head = entry._next;
- if (entry._next != null)
- entry._next._prev = entry._prev;
- else
- _end = entry._prev;
- }
-
- protected static void setHead(LineageCacheEntry entry) {
- entry._next = _head;
- entry._prev = null;
- if (_head != null)
- _head._prev = entry;
- _head = entry;
- if (_end == null)
- _end = _head;
- }
-
- //---------------- CACHE SPACE MANAGEMENT METHODS -----------------
+ //---------------- CACHE SPACE MANAGEMENT METHODS -----------------//
protected static void setCacheLimit(long limit) {
CACHE_LIMIT = limit;
}
- protected static long getCacheLimit() {
+ public static long getCacheLimit() {
return CACHE_LIMIT;
}
@@ -167,11 +130,8 @@ public class LineageCacheEviction
protected static void makeSpace(Map<LineageItem, LineageCacheEntry>
cache, long spaceNeeded) {
//Cost based eviction
- //TODO better generalization of the different policies (e.g.,
- //_head in below condition is only used when LRU is active)
- boolean isLRU =
LineageCacheConfig.getCachePolicy().isLRUcache();
- LineageCacheEntry e = isLRU ? _end : weightedQueue.poll();
- while (e != _head && e != null)
+ LineageCacheEntry e = weightedQueue.pollFirst();
+ while (e != null)
{
if ((spaceNeeded + _cachesize) <= CACHE_LIMIT)
// Enough space recovered.
@@ -181,16 +141,15 @@ public class LineageCacheEviction
// If eviction is disabled, just delete the
entries.
if (cache.remove(e._key) != null)
removeEntry(cache, e);
- e = isLRU ? e._prev : weightedQueue.poll();
+ e = weightedQueue.pollFirst();
continue;
}
- if (!e.getCacheStatus().canEvict() && isLRU) {
- // Don't delete if the entry's cache status
doesn't allow.
- // Note: no action needed for weightedQueue as
these entries
- // are not part of weightedQueue.
- e = e._prev;
+			if (!e.getCacheStatus().canEvict()) {
+				// Note: Execution should never reach here, as these
+				// entries are not part of the weightedQueue. If it does,
+				// advance to the next candidate; a bare continue would
+				// re-check the same entry forever.
+				//TODO: Graceful handling of status.
+				e = weightedQueue.pollFirst();
 				continue;
 			}
double exectime = ((double) e._computeTime) / 1000000;
// in milliseconds
@@ -200,7 +159,7 @@ public class LineageCacheEviction
// Note: scalar entries with higher computation
time are pinned.
if (cache.remove(e._key) != null)
removeEntry(cache, e);
- e = isLRU ? e._prev : weightedQueue.poll();
+ e = weightedQueue.pollFirst();
continue;
}
@@ -232,7 +191,7 @@ public class LineageCacheEviction
// Remove the entry from cache.
if (cache.remove(e._key) != null)
removeEntry(cache, e);
- e = isLRU ? e._prev : weightedQueue.poll();
+ e = weightedQueue.pollFirst();
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index 7ab7490..55bf70f 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -109,7 +109,10 @@ public class LineageCacheStatistics {
// Number of deletions from cache (including spilling).
_numMemDel.increment();
}
-
+
+ public static long getMemDeletes() {
+ return _numMemDel.longValue();
+ }
public static void incrementFSReadTime(long delta) {
// Total time spent on reading from FS.
diff --git
a/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DBackwardDataTest.java
b/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DBackwardDataTest.java
index 6122e4f..64b1e2d 100644
---
a/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DBackwardDataTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DBackwardDataTest.java
@@ -32,7 +32,6 @@ import org.apache.sysds.test.TestUtils;
public class Conv2DBackwardDataTest extends AutomatedTestBase
{
-
private final static String TEST_NAME = "Conv2DBackwardDataTest";
private final static String TEST_DIR = "functions/tensor/";
private final static String TEST_CLASS_DIR = TEST_DIR +
Conv2DBackwardDataTest.class.getSimpleName() + "/";
@@ -158,7 +157,7 @@ public class Conv2DBackwardDataTest extends
AutomatedTestBase
String sparseVal2 = (""+sparse2).toUpperCase();
long P = DnnUtils.getP(imgSize, filterSize, stride,
pad);
- programArgs = new String[]{"-explain", "-args", "" +
imgSize, "" + numImg,
+ programArgs = new String[]{"-args", "" + imgSize, "" +
numImg,
"" + numChannels, "" + numFilters,
"" + filterSize, "" + stride, "" + pad,
"" + P, "" + P,
diff --git
a/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DBackwardTest.java
b/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DBackwardTest.java
index 0b7d0f4..c901f81 100644
--- a/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DBackwardTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DBackwardTest.java
@@ -198,7 +198,7 @@ public class Conv2DBackwardTest extends AutomatedTestBase
String RI_HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
- programArgs = new String[]{"-explain", "-args",
+ programArgs = new String[]{"-args",
String.valueOf(imgSize),
String.valueOf(numImg),
String.valueOf(numChannels),
String.valueOf(numFilters),
String.valueOf(filterSize),
String.valueOf(stride), String.valueOf(pad),
diff --git a/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DTest.java
b/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DTest.java
index e07942f..ad5b567 100644
--- a/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/dnn/Conv2DTest.java
@@ -269,7 +269,7 @@ public class Conv2DTest extends AutomatedTestBase
String RI_HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
- programArgs = new String[]{"-explain",
"recompile_runtime", "-args",
+ programArgs = new String[] {"recompile_runtime",
"-args",
String.valueOf(imgSize),
String.valueOf(numImg),
String.valueOf(numChannels),
String.valueOf(numFilters),
String.valueOf(filterSize),
String.valueOf(stride), String.valueOf(pad),
diff --git
a/src/test/java/org/apache/sysds/test/functions/dnn/PoolBackwardTest.java
b/src/test/java/org/apache/sysds/test/functions/dnn/PoolBackwardTest.java
index 35efa87..2764801 100644
--- a/src/test/java/org/apache/sysds/test/functions/dnn/PoolBackwardTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/dnn/PoolBackwardTest.java
@@ -155,7 +155,7 @@ public class PoolBackwardTest extends AutomatedTestBase
String RI_HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
- programArgs = new String[]{"-explain", "-args",
String.valueOf(imgSize), String.valueOf(numImg),
+ programArgs = new String[]{"-args",
String.valueOf(imgSize), String.valueOf(numImg),
String.valueOf(numChannels),
String.valueOf(poolSize1), String.valueOf(poolSize2),
String.valueOf(stride),
String.valueOf(pad), String.valueOf(poolMode),
String.valueOf(P), String.valueOf(P),
output("B"), sparseVal1, sparseVal2};
diff --git a/src/test/java/org/apache/sysds/test/functions/dnn/PoolTest.java
b/src/test/java/org/apache/sysds/test/functions/dnn/PoolTest.java
index 105b1e3..bff1001 100644
--- a/src/test/java/org/apache/sysds/test/functions/dnn/PoolTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/dnn/PoolTest.java
@@ -151,7 +151,7 @@ public class PoolTest extends AutomatedTestBase
String RI_HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
- programArgs = new String[]{"-explain", "-args",
String.valueOf(imgSize),
+ programArgs = new String[]{"-args",
String.valueOf(imgSize),
String.valueOf(numImg),
String.valueOf(numChannels),
String.valueOf(poolSize1),
String.valueOf(poolSize2),
String.valueOf(stride), String.valueOf(pad),
poolMode,
diff --git
a/src/test/java/org/apache/sysds/test/functions/dnn/ReluBackwardTest.java
b/src/test/java/org/apache/sysds/test/functions/dnn/ReluBackwardTest.java
index 9eb4866..cee22b9 100644
--- a/src/test/java/org/apache/sysds/test/functions/dnn/ReluBackwardTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/dnn/ReluBackwardTest.java
@@ -38,60 +38,48 @@ public class ReluBackwardTest extends AutomatedTestBase
@Override
public void setUp() {
- addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,
- new String[] {"B"}));
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
}
@Test
- public void testReluBackwardDense1()
- {
+ public void testReluBackwardDense1() {
runReluBackwardTest(ExecType.CP, 10, 100);
}
@Test
- public void testReluBackwardDense2()
- {
+ public void testReluBackwardDense2() {
runReluBackwardTest(ExecType.CP, 100, 10);
}
@Test
- public void testReluBackwardDense3()
- {
+ public void testReluBackwardDense3() {
runReluBackwardTest(ExecType.CP, 100, 100);
}
- /**
- *
- * @param et
- * @param sparse
- */
public void runReluBackwardTest( ExecType et, int M, int N)
{
ExecMode oldRTP = rtplatform;
-
+
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
try
{
- TestConfiguration config = getTestConfiguration(TEST_NAME);
- if(et == ExecType.SPARK) {
- rtplatform = ExecMode.SPARK;
- }
- else {
- rtplatform = ExecMode.SINGLE_NODE;
- }
+ TestConfiguration config =
getTestConfiguration(TEST_NAME);
+ if(et == ExecType.SPARK) {
+ rtplatform = ExecMode.SPARK;
+ }
+ else {
+ rtplatform = ExecMode.SINGLE_NODE;
+ }
if( rtplatform == ExecMode.SPARK )
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
loadTestConfiguration(config);
-
- /* This is for running the junit test the new way,
i.e., construct the arguments directly */
+
String RI_HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-args", "" + M, "" + N,
output("B")};
- programArgs = new String[]{"-explain", "-args", "" +
M, "" + N,
- output("B")};
-
boolean exceptionExpected = false;
int expectedNumberOfJobs = -1;
runTest(true, exceptionExpected, null,
expectedNumberOfJobs);
@@ -107,11 +95,9 @@ public class ReluBackwardTest extends AutomatedTestBase
TestUtils.compareMatrices(dmlfile, bHM, epsilon,
"B-DML", "NumPy");
}
- finally
- {
+ finally {
rtplatform = oldRTP;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}
-
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java
b/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java
new file mode 100644
index 0000000..9f925c5
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.lineage;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageCacheEviction;
+import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+public class CacheEvictionTest extends AutomatedTestBase {
+
+	protected static final String TEST_DIR = "functions/lineage/";
+	protected static final String TEST_NAME1 = "CacheEviction1";
+
+	protected String TEST_CLASS_DIR = TEST_DIR + CacheEvictionTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+	}
+
+	@Test
+	public void testEvictionOrder() {
+		runTest(TEST_NAME1);
+	}
+
+	public void runTest(String testname) {
+		boolean old_simplification = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		boolean old_sum_product = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+		// The cache policy is static state; remember it so that later tests
+		// running in the same JVM do not inherit the policies set below.
+		LineageCacheConfig.LineageCachePolicy old_policy = LineageCacheConfig.getCachePolicy();
+
+		try {
+			System.out.println("------------ BEGIN " + testname + "------------");
+
+			/* This test verifies the order of evicted items w.r.t. the specified
+			 * cache policies. This test enables individual components of the
+			 * scoring function by masking the other components, and compare the
+			 * order of evicted entries for different policies. HYBRID policy is
+			 * not considered for this test as it is hard to anticipate the reuse
+			 * statistics if all the components are unmasked.
+			 *
+			 * TODO: Test disk spilling, which will need some tunings in eviction
+			 * logic; otherwise the automated test might take significantly
+			 * longer as eviction logic tends to just delete entries with little
+			 * computation and estimated I/O time. Note that disk spilling is
+			 * already happening as part of other tests (e.g. MultiLogReg).
+			 */
+
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false;
+
+			getAndLoadTestConfiguration(testname);
+			fullDMLScriptName = getScript();
+			Lineage.resetInternalState();
+			long cacheSize = LineageCacheEviction.getCacheLimit();
+
+			// LRU based eviction
+			programArgs = getProgramArgs(cacheSize);
+			LineageCacheConfig.setCachePolicy(LineageCacheConfig.LineageCachePolicy.LRU);
+			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> R_lru = readDMLMatrixFromHDFS("R");
+			long expCount_lru = Statistics.getCPHeavyHitterCount("exp");
+			long plusCount_lru = Statistics.getCPHeavyHitterCount("+");
+			long evictedCount_lru = LineageCacheStatistics.getMemDeletes();
+
+			// Weighted scheme (computationTime/Size)
+			programArgs = getProgramArgs(cacheSize);
+			Lineage.resetInternalState();
+			LineageCacheConfig.setCachePolicy(LineageCacheConfig.LineageCachePolicy.WEIGHTED);
+			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> R_weighted= readDMLMatrixFromHDFS("R");
+			long expCount_wt = Statistics.getCPHeavyHitterCount("exp");
+			long plusCount_wt = Statistics.getCPHeavyHitterCount("+");
+			long evictedCount_wt = LineageCacheStatistics.getMemDeletes();
+
+			// Compare results
+			Lineage.setLinReuseNone();
+			TestUtils.compareMatrices(R_lru, R_weighted, 1e-6, "LRU", "Weighted");
+
+			// Compare reused instructions
+			Assert.assertTrue(expCount_lru > expCount_wt);
+			Assert.assertTrue(plusCount_lru < plusCount_wt);
+
+			// Compare counts of evicted items
+			// LRU tends to evict more entries to recover equal amount of memory
+			Assert.assertTrue(evictedCount_lru > evictedCount_wt);
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = old_simplification;
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = old_sum_product;
+			// Restore the eviction weights. A null policy means none was set
+			// explicitly; the default weights {0, 1} are the LRU weights.
+			LineageCacheConfig.setCachePolicy(old_policy != null ?
+				old_policy : LineageCacheConfig.LineageCachePolicy.LRU);
+			Recompiler.reinitRecompiler();
+		}
+	}
+
+	/**
+	 * Builds the DML program arguments shared by all runs of this test:
+	 * statistics, full lineage reuse, the cache size and the output path.
+	 */
+	private String[] getProgramArgs(long cacheSize) {
+		List<String> proArgs = new ArrayList<>();
+		proArgs.add("-stats");
+		proArgs.add("-lineage");
+		proArgs.add(ReuseCacheType.REUSE_FULL.name().toLowerCase());
+		proArgs.add("-args");
+		proArgs.add(String.valueOf(cacheSize));
+		proArgs.add(output("R"));
+		return proArgs.toArray(new String[proArgs.size()]);
+	}
+
+}
diff --git a/src/test/scripts/functions/lineage/CacheEviction1.dml
b/src/test/scripts/functions/lineage/CacheEviction1.dml
new file mode 100644
index 0000000..f25ad6d
--- /dev/null
+++ b/src/test/scripts/functions/lineage/CacheEviction1.dml
@@ -0,0 +1,55 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+cache_size = ceil($1/(1024*1024)); #in MB
+output_size = 8; #8MB
+X = rand(rows=1024, cols=1024, sparsity = 1.0, seed=42);
+X1 = X;
+R = matrix(0, 1024, 1024);
+R1 = R;
+k = floor((cache_size / output_size));
+
+# Fill the cache with 'exp' and '+' outputs
+for (i in 1:k/2) {
+ R = exp(X);
+ X = X + 1;
+}
+
+# Trigger eviction. LRU evicts both 'exp' and '+' results,
+# where Weighted scheme evicts only '+' results to recover
+# same amount of memory.
+for (i in 1:1.5*k/4) {
+ R = round(X);
+ X = X + 1;
+}
+
+
+# Try to reuse 'exp' and '+' results. LRU reuses less
+# 'exp' outputs but more '+' outputs.
+for (i in 1:k/4) {
+ R1 = exp(X1);
+ X1 = X1 + 1;
+}
+
+R = R+R1;
+write(R, $2, format="text");
+
diff --git a/src/test/scripts/functions/lineage/LineageReuseAlg3.dml
b/src/test/scripts/functions/lineage/LineageReuseAlg3.dml
index b5a612d..6b095cb 100644
--- a/src/test/scripts/functions/lineage/LineageReuseAlg3.dml
+++ b/src/test/scripts/functions/lineage/LineageReuseAlg3.dml
@@ -49,7 +49,7 @@ findIcpt = function(Matrix[double] X, Matrix[double] y)
}
-X = rand(rows=1000, cols=100, sparsity=1.0, seed=42);
+X = rand(rows=1000, cols=1000, sparsity=1.0, seed=42);
y = rand(rows=1000, cols=1, min=0, max=6, sparsity=1.0, seed=42);
y = floor(y);