This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a070f63b46 [SYSTEMDS-3585] Reuse lineage traces from lineage cache
a070f63b46 is described below
commit a070f63b46611127dbf371f9fdc40730c5619f9c
Author: Arnab Phani <[email protected]>
AuthorDate: Thu Jun 29 13:25:09 2023 +0200
[SYSTEMDS-3585] Reuse lineage traces from lineage cache
This commit adds a small but useful extension to lineage-based reuse.
We now also reuse the lineage traces corresponding to the reused
intermediates by replacing the live lineage traces with the cached ones.
This change increases the use of same lineage items in many lineage
DAGs, which in turn reduces probing cost and memory overhead.
This extension is disabled for parfor and deduplicated lineage traces.
Integrating with those requires more thought.
Closes #1853
---
.../runtime/controlprogram/ParForProgramBlock.java | 5 +++++
.../controlprogram/context/ExecutionContext.java | 13 +++++++++++++
.../instructions/spark/MapmmChainSPInstruction.java | 16 ++++++++++++++--
.../java/org/apache/sysds/runtime/lineage/Lineage.java | 1 +
.../org/apache/sysds/runtime/lineage/LineageCache.java | 14 +++++++++++---
.../sysds/runtime/lineage/LineageCacheConfig.java | 17 +++++++++++++----
.../sysds/runtime/lineage/LineageCacheStatistics.java | 8 ++++++++
.../org/apache/sysds/runtime/lineage/LineageItem.java | 5 ++++-
src/main/java/org/apache/sysds/utils/Statistics.java | 2 +-
9 files changed, 70 insertions(+), 11 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 4df1a7052e..2fc12c4c26 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -84,6 +84,7 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.meta.DataCharacteristics;
@@ -800,6 +801,7 @@ public class ParForProgramBlock extends ForProgramBlock
StatisticMonitor.putPFStat(_ID,
Stat.PARFOR_INIT_TASKS_T, time.stop());
// Step 3) join all threads (wait for finished work)
+ LineageCacheConfig.setReuseLineageTraces(false);
//disable lineage trace reuse
for( Thread thread : threads )
thread.join();
@@ -823,6 +825,7 @@ public class ParForProgramBlock extends ForProgramBlock
.map(w -> w.getExecutionContext().getLineage())
.toArray(Lineage[]::new);
mergeLineage(ec, lineages);
+ //LineageCacheConfig.setReuseLineageTraces(true);
//consolidate results into global symbol table
consolidateAndCheckResults( ec, numIterations,
numCreatedTasks,
@@ -900,6 +903,7 @@ public class ParForProgramBlock extends ForProgramBlock
exportMatricesToHDFS(ec, brVars);
// Step 3) submit Spark parfor job (no lazy evaluation, since
collect on result)
+ LineageCacheConfig.setReuseLineageTraces(false); //disable
lineage trace reuse
boolean topLevelPF = OptimizerUtils.isTopLevelParFor();
RemoteParForJobReturn ret = RemoteParForSpark.runJob(_ID,
program,
clsMap, tasks, ec, brVars, _resultVars,
_enableCPCaching, _numThreads, topLevelPF);
@@ -913,6 +917,7 @@ public class ParForProgramBlock extends ForProgramBlock
//lineage maintenance
mergeLineage(ec, ret.getLineages());
+ //LineageCacheConfig.setReuseLineageTraces(true);
// TODO: remove duplicate lineage items in ec.getLineage()
//consolidate results into global symbol table
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index cdedfb9e45..9c8547f615 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -167,6 +167,7 @@ public class ExecutionContext {
}
/**
+ *
* Get the i-th GPUContext
* @param index index of the GPUContext
* @return a valid GPUContext or null if the indexed GPUContext does
not exist.
@@ -924,6 +925,18 @@ public class ExecutionContext {
throw new DMLRuntimeException("Lineage Trace
unavailable.");
return _lineage.getOrCreate(input);
}
+
+	/**
+	 * Replaces the live lineage trace of a variable with an equivalent lineage
+	 * item (e.g., the key of a reused lineage cache entry), so that subsequent
+	 * lineage DAGs share that item instead of a freshly created duplicate.
+	 * Does nothing if lineage trace reuse is disabled in LineageCacheConfig
+	 * (it is turned off for parfor and lineage deduplication).
+	 *
+	 * @param varname name of the live variable whose lineage trace is replaced
+	 * @param li      lineage item equivalent to the current trace of {@code varname}
+	 * @throws DMLRuntimeException if lineage tracing is unavailable, or if
+	 *         {@code varname} has no lineage item yet
+	 */
+	public void replaceLineageItem(String varname, LineageItem li) {
+		if (!LineageCacheConfig.isLineageTraceReuse())
+			return;
+		if( _lineage == null )
+			throw new DMLRuntimeException("Lineage Trace unavailable.");
+		if (_lineage.get(varname) == null)
+			throw new DMLRuntimeException("Lineage item does not exist for "+varname);
+		//Passed lineage trace should be equivalent to the live lineage trace
+		//corresponding to varname. Replacing reduces memory and probing overheads.
+		_lineage.set(varname, li);
+	}
private static String getNonExistingVarError(String varname) {
return "Variable '" + varname + "' does not exist in the symbol
table.";
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
index b1c8248579..e2f4e5d270 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java
@@ -20,9 +20,11 @@
package org.apache.sysds.runtime.instructions.spark;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
+import org.apache.sysds.common.Types;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.lops.MapMultChain.ChainType;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -32,12 +34,15 @@ import
org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import scala.Tuple2;
-public class MapmmChainSPInstruction extends SPInstruction {
+public class MapmmChainSPInstruction extends SPInstruction implements
LineageTraceable {
private ChainType _chainType = null;
private CPOperand _input1 = null;
private CPOperand _input2 = null;
@@ -116,7 +121,14 @@ public class MapmmChainSPInstruction extends SPInstruction
{
//this also includes implicit maintenance of matrix
characteristics
sec.setMatrixOutput(_output.getName(), out);
}
-
+
+	/**
+	 * Builds the lineage item of this instruction's output from its matrix
+	 * inputs and the chain type.
+	 *
+	 * @param ec execution context providing the live lineage of the inputs
+	 * @return pair of output variable name and its new lineage item
+	 */
+	@Override
+	public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+		//The literal's value is the enum name (a string such as XtXv), so declare
+		//it as STRING; INT64 would misdescribe the literal in the lineage trace.
+		CPOperand chainT = new CPOperand(_chainType.name(), Types.ValueType.STRING, Types.DataType.SCALAR, true);
+		//NOTE(review): _input3 is presumably null for single-broadcast chain types
+		//(XtXv); assumes LineageItemUtils.getLineage skips null operands -- confirm.
+		return Pair.of(_output.getName(), new LineageItem(getOpcode(),
+			LineageItemUtils.getLineage(ec, _input1, _input2, _input3, chainT)));
+	}
+
/**
* This function implements the chain type XtXv which requires just one
broadcast and
* no access to any indexes of matrix blocks.
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
index c233e55e23..866ebc8120 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java
@@ -132,6 +132,7 @@ public class Lineage {
public void initializeDedupBlock(ProgramBlock pb, ExecutionContext ec) {
if( !(pb instanceof ForProgramBlock || pb instanceof
WhileProgramBlock) )
throw new DMLRuntimeException("Invalid deduplication
block: "+ pb.getClass().getSimpleName());
+ LineageCacheConfig.setReuseLineageTraces(false);
if (!_dedupBlocks.containsKey(pb)) {
// valid only if doesn't contain a nested loop
boolean valid = LineageDedupUtils.isValidDedupBlock(pb,
false);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index a51c0ae9e3..1a3b12d7a9 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -108,7 +108,8 @@ public class LineageCache
//try to reuse full or partial intermediates
(CPU and FED only)
for (MutablePair<LineageItem,LineageCacheEntry>
item : liList) {
if
(LineageCacheConfig.getCacheType().isFullReuse())
- e =
LineageCache.probe(item.getKey()) ? getIntern(item.getKey()) : null;
+ //e =
LineageCache.probe(item.getKey()) ? getIntern(item.getKey()) : null;
+ e = getIntern(item.getKey());
//avoid double probing (containsKey + get)
//TODO need to also move execution of
compensation plan out of here
//(create lazily evaluated entry)
if (e == null &&
LineageCacheConfig.getCacheType().isPartialReuse()
@@ -162,6 +163,7 @@ public class LineageCache
//Even not
persisted, reuse the rdd locally for shuffle operations
if
(!LineageCacheConfig.isShuffleOp(inst))
return
false;
+
((SparkExecutionContext) ec).setRDDHandleForVariable(outName, rdd);
break;
case PERSISTEDRDD:
@@ -184,6 +186,8 @@ public class LineageCache
//Increment the live count for
this pointer
LineageGPUCacheEviction.incrementLiveCount(e.getGPUPointer());
}
+ //Replace the live lineage trace with
the cached one (if not parfor, dedup)
+ ec.replaceLineageItem(outName, e._key);
}
maintainReuseStatistics(ec, inst,
liList.get(0).getValue());
}
@@ -444,6 +448,7 @@ public class LineageCache
if (!p && DMLScript.STATISTICS &&
LineageCacheEviction._removelist.containsKey(key))
// The sought entry was in cache but removed later
LineageCacheStatistics.incrementDelHits();
+
return p;
}
@@ -949,9 +954,11 @@ public class LineageCache
}
private static LineageCacheEntry getIntern(LineageItem key) {
- // This method is called only when entry is present either in
cache or in local FS.
LineageCacheEntry e = _cache.get(key);
- if (e != null && e.getCacheStatus() !=
LineageCacheStatus.SPILLED) {
+ if (e == null)
+ return null;
+
+ if (e.getCacheStatus() != LineageCacheStatus.SPILLED) {
if (DMLScript.STATISTICS)
// Increment hit count.
LineageCacheStatistics.incrementMemHits();
@@ -1222,6 +1229,7 @@ public class LineageCache
//TODO: Replace with generic type
List<MutablePair<LineageItem, LineageCacheEntry>> liList = null;
+ //FIXME: Replace getLineageItem with get/getOrCreate to avoid
creating a new LI object
LineageItem instLI = (cinst != null) ?
cinst.getLineageItem(ec).getValue()
: (cfinst != null) ?
cfinst.getLineageItem(ec).getValue()
: (cspinst != null) ?
cspinst.getLineageItem(ec).getValue()
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index d0e32570b9..63863f7029 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -50,11 +50,11 @@ public class LineageCacheConfig
private static final String[] OPCODES = new String[] {
"tsmm", "ba+*", "*", "/", "+", "||", "nrow", "ncol", "round",
"exp", "log",
"rightIndex", "leftIndex", "groupedagg", "r'", "solve", "spoof",
- "uamean", "max", "min", "ifelse", "-", "sqrt", ">", "uak+",
"<=",
+ "uamean", "max", "min", "ifelse", "-", "sqrt", "<", ">",
"uak+", "<=",
"^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand",
"replace",
- "^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax",
"qsort",
+ "^2", "*2", "uack+", "tak+*", "uacsqk+", "uark+", "n+",
"uarimax", "qsort",
"qpick", "transformapply", "uarmax", "n+", "-*", "castdtm",
"lowertri",
- "prefetch", "mapmm"
+ "prefetch", "mapmm", "contains", "mmchain", "mapmmchain", "+*"
//TODO: Reuse everything.
};
@@ -70,7 +70,7 @@ public class LineageCacheConfig
// Relatively inexpensive instructions.
private static final String[] PERSIST_OPCODES2 = new String[] {
	"mapmm"
};
private static String[] REUSE_OPCODES = new String[] {};
@@ -104,6 +104,7 @@ public class LineageCacheConfig
private static CachedItemTail _itemT = null;
private static boolean _compilerAssistedRW = false;
private static boolean _onlyEstimate = false;
+	//Whether live lineage traces may be replaced by cached lineage items on reuse.
+	//Toggled off for parfor and lineage deduplication; volatile because it may be
+	//read by parfor worker threads while the driver thread updates it.
+	private static volatile boolean _reuseLineageTraces = true;
//-------------DISK SPILLING RELATED CONFIGURATIONS--------------//
@@ -368,6 +369,14 @@ public class LineageCacheConfig
return _compilerAssistedRW;
}
+	/**
+	 * Enables or disables replacing live lineage traces with the lineage
+	 * items of reused lineage cache entries.
+	 *
+	 * @param reuseTrace true to enable lineage trace reuse, false to disable it
+	 */
+	public static void setReuseLineageTraces(boolean reuseTrace) {
+		LineageCacheConfig._reuseLineageTraces = reuseTrace;
+	}
+
+	/**
+	 * Reports whether lineage trace reuse is currently enabled.
+	 *
+	 * @return true if live lineage traces may be replaced by cached lineage items
+	 */
+	public static boolean isLineageTraceReuse() {
+		final boolean enabled = LineageCacheConfig._reuseLineageTraces;
+		return enabled;
+	}
+
public static void setCachePolicy(LineageCachePolicy policy) {
// TODO: Automatic tuning of weights.
switch(policy) {
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index c7bbd6a00d..fee5b8a0dd 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -41,6 +41,7 @@ public class LineageCacheStatistics {
private static final LongAdder _ctimeFSWrite = new LongAdder();
private static final LongAdder _ctimeSaved = new LongAdder();
private static final LongAdder _ctimeMissed = new LongAdder();
+ private static final LongAdder _ctimeProbe = new LongAdder();
// Bellow entries are specific to gpu lineage cache
private static final LongAdder _numHitsGpu = new LongAdder();
private static final LongAdder _numAsyncEvictGpu= new LongAdder();
@@ -70,6 +71,7 @@ public class LineageCacheStatistics {
_ctimeFSWrite.reset();
_ctimeSaved.reset();
_ctimeMissed.reset();
+ _ctimeProbe.reset();
_evtimeGpu.reset();
_numHitsGpu.reset();
_numAsyncEvictGpu.reset();
@@ -191,6 +193,10 @@ public class LineageCacheStatistics {
_ctimeMissed.add(delta);
}
+	/**
+	 * Adds to the accumulated lineage probing time, which is reported as the
+	 * third entry of the compute-time statistics.
+	 *
+	 * @param delta elapsed probing time to add, in nanoseconds
+	 */
+	public static void incrementProbeTime(long delta) {
+		LineageCacheStatistics._ctimeProbe.add(delta);
+	}
+
public static long getMultiLevelFnHits() {
return _numHitsFunc.longValue();
}
@@ -303,6 +309,8 @@ public class LineageCacheStatistics {
sb.append(String.format("%.3f",
((double)_ctimeSaved.longValue())/1000000000)); //in sec
sb.append("/");
sb.append(String.format("%.3f",
((double)_ctimeMissed.longValue())/1000000000)); //in sec
+ sb.append("/");
+ sb.append(String.format("%.3f",
((double)_ctimeProbe.longValue())/1000000000)); //in sec
return sb.toString();
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
index 311dae2a86..943f497937 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Stack;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -269,8 +270,8 @@ public class LineageItem {
Stack<LineageItem> s2 = new Stack<>();
s1.push(this);
s2.push(that);
- //boolean ret = false;
boolean ret = true;
+ long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
while (!s1.empty() && !s2.empty()) {
LineageItem li1 = s1.pop();
LineageItem li2 = s2.pop();
@@ -356,6 +357,8 @@ public class LineageItem {
}
li1.setVisited();
}
+ if (DMLScript.STATISTICS) //increment probing time
+
LineageCacheStatistics.incrementProbeTime(System.nanoTime() - t0);
return ret;
}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 9d931f3ce3..6978507179 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -649,7 +649,7 @@ public class Statistics
}
sb.append("LinCache writes (Mem/FS/Del): \t" +
LineageCacheStatistics.displayWtrites() + ".\n");
sb.append("LinCache FStimes (Rd/Wr): \t" +
LineageCacheStatistics.displayFSTime() + " sec.\n");
- sb.append("LinCache Computetime (S/M): \t" +
LineageCacheStatistics.displayComputeTime() + " sec.\n");
+ sb.append("LinCache Computetime (S/M/P): \t" +
LineageCacheStatistics.displayComputeTime() + " sec.\n");
sb.append("LinCache Rewrites: \t\t" +
LineageCacheStatistics.displayRewrites() + ".\n");
}