This is an automated email from the ASF dual-hosted git repository. arnabp20 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit a070f63b46611127dbf371f9fdc40730c5619f9c Author: Arnab Phani <[email protected]> AuthorDate: Thu Jun 29 13:25:09 2023 +0200 [SYSTEMDS-3585] Reuse lineage traces from lineage cache This commit adds a small but useful extension to lineage-based reuse. We now also reuse the lineage traces corresponding to the reused intermediates by replacing the live lineage traces with the cached ones. This change increases the use of the same lineage items in many lineage DAGs, which in turn reduces probing cost and memory overhead. This extension is disabled for parfor and deduplicated lineage traces. Integrating with those requires more thought. Closes #1853 --- .../runtime/controlprogram/ParForProgramBlock.java | 5 +++++ .../controlprogram/context/ExecutionContext.java | 13 +++++++++++++ .../instructions/spark/MapmmChainSPInstruction.java | 16 ++++++++++++++-- .../java/org/apache/sysds/runtime/lineage/Lineage.java | 1 + .../org/apache/sysds/runtime/lineage/LineageCache.java | 14 +++++++++++--- .../sysds/runtime/lineage/LineageCacheConfig.java | 17 +++++++++++++---- .../sysds/runtime/lineage/LineageCacheStatistics.java | 8 ++++++++ .../org/apache/sysds/runtime/lineage/LineageItem.java | 5 ++++- src/main/java/org/apache/sysds/utils/Statistics.java | 2 +- 9 files changed, 70 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java index 4df1a7052e..2fc12c4c26 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java @@ -84,6 +84,7 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.cp.StringObject; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.lineage.Lineage; +import 
org.apache.sysds.runtime.lineage.LineageCacheConfig; import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.lineage.LineageItemUtils; import org.apache.sysds.runtime.meta.DataCharacteristics; @@ -800,6 +801,7 @@ public class ParForProgramBlock extends ForProgramBlock StatisticMonitor.putPFStat(_ID, Stat.PARFOR_INIT_TASKS_T, time.stop()); // Step 3) join all threads (wait for finished work) + LineageCacheConfig.setReuseLineageTraces(false); //disable lineage trace reuse for( Thread thread : threads ) thread.join(); @@ -823,6 +825,7 @@ public class ParForProgramBlock extends ForProgramBlock .map(w -> w.getExecutionContext().getLineage()) .toArray(Lineage[]::new); mergeLineage(ec, lineages); + //LineageCacheConfig.setReuseLineageTraces(true); //consolidate results into global symbol table consolidateAndCheckResults( ec, numIterations, numCreatedTasks, @@ -900,6 +903,7 @@ public class ParForProgramBlock extends ForProgramBlock exportMatricesToHDFS(ec, brVars); // Step 3) submit Spark parfor job (no lazy evaluation, since collect on result) + LineageCacheConfig.setReuseLineageTraces(false); //disable lineage trace reuse boolean topLevelPF = OptimizerUtils.isTopLevelParFor(); RemoteParForJobReturn ret = RemoteParForSpark.runJob(_ID, program, clsMap, tasks, ec, brVars, _resultVars, _enableCPCaching, _numThreads, topLevelPF); @@ -913,6 +917,7 @@ public class ParForProgramBlock extends ForProgramBlock //lineage maintenance mergeLineage(ec, ret.getLineages()); + //LineageCacheConfig.setReuseLineageTraces(true); // TODO: remove duplicate lineage items in ec.getLineage() //consolidate results into global symbol table diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java index cdedfb9e45..9c8547f615 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java +++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java @@ -167,6 +167,7 @@ public class ExecutionContext { } /** + * * Get the i-th GPUContext * @param index index of the GPUContext * @return a valid GPUContext or null if the indexed GPUContext does not exist. @@ -924,6 +925,18 @@ public class ExecutionContext { throw new DMLRuntimeException("Lineage Trace unavailable."); return _lineage.getOrCreate(input); } + + public void replaceLineageItem(String varname, LineageItem li) { + if (!LineageCacheConfig.isLineageTraceReuse()) + return; + if( _lineage == null ) + throw new DMLRuntimeException("Lineage Trace unavailable."); + if (_lineage.get(varname) == null) + throw new DMLRuntimeException("Lineage item does not exist for "+varname); + //Passed lineage trace should be equivalent to the live lineage trace + //corresponding to varname. Replacing reduces memory and probing overheads. + _lineage.set(varname, li); + } private static String getNonExistingVarError(String varname) { return "Variable '" + varname + "' does not exist in the symbol table."; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java index b1c8248579..e2f4e5d270 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmChainSPInstruction.java @@ -20,9 +20,11 @@ package org.apache.sysds.runtime.instructions.spark; +import org.apache.commons.lang3.tuple.Pair; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; +import org.apache.sysds.common.Types; import org.apache.sysds.lops.MapMultChain; import org.apache.sysds.lops.MapMultChain.ChainType; import org.apache.sysds.runtime.DMLRuntimeException; @@ -32,12 +34,15 @@ import 
org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast; import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils; +import org.apache.sysds.runtime.lineage.LineageItem; +import org.apache.sysds.runtime.lineage.LineageItemUtils; +import org.apache.sysds.runtime.lineage.LineageTraceable; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.operators.Operator; import scala.Tuple2; -public class MapmmChainSPInstruction extends SPInstruction { +public class MapmmChainSPInstruction extends SPInstruction implements LineageTraceable { private ChainType _chainType = null; private CPOperand _input1 = null; private CPOperand _input2 = null; @@ -116,7 +121,14 @@ public class MapmmChainSPInstruction extends SPInstruction { //this also includes implicit maintenance of matrix characteristics sec.setMatrixOutput(_output.getName(), out); } - + + @Override + public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { + CPOperand chainT = new CPOperand(_chainType.name(), Types.ValueType.INT64, Types.DataType.SCALAR, true); + return Pair.of(_output.getName(), new LineageItem(getOpcode(), + LineageItemUtils.getLineage(ec, _input1, _input2, _input3, chainT))); + } + /** * This function implements the chain type XtXv which requires just one broadcast and * no access to any indexes of matrix blocks. 
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java index c233e55e23..866ebc8120 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/Lineage.java @@ -132,6 +132,7 @@ public class Lineage { public void initializeDedupBlock(ProgramBlock pb, ExecutionContext ec) { if( !(pb instanceof ForProgramBlock || pb instanceof WhileProgramBlock) ) throw new DMLRuntimeException("Invalid deduplication block: "+ pb.getClass().getSimpleName()); + LineageCacheConfig.setReuseLineageTraces(false); if (!_dedupBlocks.containsKey(pb)) { // valid only if doesn't contain a nested loop boolean valid = LineageDedupUtils.isValidDedupBlock(pb, false); diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java index a51c0ae9e3..1a3b12d7a9 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java @@ -108,7 +108,8 @@ public class LineageCache //try to reuse full or partial intermediates (CPU and FED only) for (MutablePair<LineageItem,LineageCacheEntry> item : liList) { if (LineageCacheConfig.getCacheType().isFullReuse()) - e = LineageCache.probe(item.getKey()) ? getIntern(item.getKey()) : null; + //e = LineageCache.probe(item.getKey()) ? 
getIntern(item.getKey()) : null; + e = getIntern(item.getKey()); //avoid double probing (containsKey + get) //TODO need to also move execution of compensation plan out of here //(create lazily evaluated entry) if (e == null && LineageCacheConfig.getCacheType().isPartialReuse() @@ -162,6 +163,7 @@ public class LineageCache //Even not persisted, reuse the rdd locally for shuffle operations if (!LineageCacheConfig.isShuffleOp(inst)) return false; + ((SparkExecutionContext) ec).setRDDHandleForVariable(outName, rdd); break; case PERSISTEDRDD: @@ -184,6 +186,8 @@ public class LineageCache //Increment the live count for this pointer LineageGPUCacheEviction.incrementLiveCount(e.getGPUPointer()); } + //Replace the live lineage trace with the cached one (if not parfor, dedup) + ec.replaceLineageItem(outName, e._key); } maintainReuseStatistics(ec, inst, liList.get(0).getValue()); } @@ -444,6 +448,7 @@ public class LineageCache if (!p && DMLScript.STATISTICS && LineageCacheEviction._removelist.containsKey(key)) // The sought entry was in cache but removed later LineageCacheStatistics.incrementDelHits(); + return p; } @@ -949,9 +954,11 @@ public class LineageCache } private static LineageCacheEntry getIntern(LineageItem key) { - // This method is called only when entry is present either in cache or in local FS. LineageCacheEntry e = _cache.get(key); - if (e != null && e.getCacheStatus() != LineageCacheStatus.SPILLED) { + if (e == null) + return null; + + if (e.getCacheStatus() != LineageCacheStatus.SPILLED) { if (DMLScript.STATISTICS) // Increment hit count. LineageCacheStatistics.incrementMemHits(); @@ -1222,6 +1229,7 @@ public class LineageCache //TODO: Replace with generic type List<MutablePair<LineageItem, LineageCacheEntry>> liList = null; + //FIXME: Replace getLineageItem with get/getOrCreate to avoid creating a new LI object LineageItem instLI = (cinst != null) ? cinst.getLineageItem(ec).getValue() : (cfinst != null) ? 
cfinst.getLineageItem(ec).getValue() : (cspinst != null) ? cspinst.getLineageItem(ec).getValue() diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java index d0e32570b9..63863f7029 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java @@ -50,11 +50,11 @@ public class LineageCacheConfig private static final String[] OPCODES = new String[] { "tsmm", "ba+*", "*", "/", "+", "||", "nrow", "ncol", "round", "exp", "log", "rightIndex", "leftIndex", "groupedagg", "r'", "solve", "spoof", - "uamean", "max", "min", "ifelse", "-", "sqrt", ">", "uak+", "<=", + "uamean", "max", "min", "ifelse", "-", "sqrt", "<", ">", "uak+", "<=", "^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand", "replace", - "^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax", "qsort", + "^2", "*2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax", "qsort", "qpick", "transformapply", "uarmax", "n+", "-*", "castdtm", "lowertri", - "prefetch", "mapmm" + "prefetch", "mapmm", "contains", "mmchain", "mapmmchain", "+*" //TODO: Reuse everything. }; @@ -70,7 +70,7 @@ public class LineageCacheConfig // Relatively inexpensive instructions. 
private static final String[] PERSIST_OPCODES2 = new String[] { - "mapmm" + "mapmm," }; private static String[] REUSE_OPCODES = new String[] {}; @@ -104,6 +104,7 @@ public class LineageCacheConfig private static CachedItemTail _itemT = null; private static boolean _compilerAssistedRW = false; private static boolean _onlyEstimate = false; + private static boolean _reuseLineageTraces = true; //-------------DISK SPILLING RELATED CONFIGURATIONS--------------// @@ -368,6 +369,14 @@ public class LineageCacheConfig return _compilerAssistedRW; } + public static void setReuseLineageTraces(boolean reuseTrace) { + _reuseLineageTraces = reuseTrace; + } + + public static boolean isLineageTraceReuse() { + return _reuseLineageTraces; + } + public static void setCachePolicy(LineageCachePolicy policy) { // TODO: Automatic tuning of weights. switch(policy) { diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java index c7bbd6a00d..fee5b8a0dd 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java @@ -41,6 +41,7 @@ public class LineageCacheStatistics { private static final LongAdder _ctimeFSWrite = new LongAdder(); private static final LongAdder _ctimeSaved = new LongAdder(); private static final LongAdder _ctimeMissed = new LongAdder(); + private static final LongAdder _ctimeProbe = new LongAdder(); // Bellow entries are specific to gpu lineage cache private static final LongAdder _numHitsGpu = new LongAdder(); private static final LongAdder _numAsyncEvictGpu= new LongAdder(); @@ -70,6 +71,7 @@ public class LineageCacheStatistics { _ctimeFSWrite.reset(); _ctimeSaved.reset(); _ctimeMissed.reset(); + _ctimeProbe.reset(); _evtimeGpu.reset(); _numHitsGpu.reset(); _numAsyncEvictGpu.reset(); @@ -191,6 +193,10 @@ public class LineageCacheStatistics { 
_ctimeMissed.add(delta); } + public static void incrementProbeTime(long delta) { + _ctimeProbe.add(delta); + } + public static long getMultiLevelFnHits() { return _numHitsFunc.longValue(); } @@ -303,6 +309,8 @@ public class LineageCacheStatistics { sb.append(String.format("%.3f", ((double)_ctimeSaved.longValue())/1000000000)); //in sec sb.append("/"); sb.append(String.format("%.3f", ((double)_ctimeMissed.longValue())/1000000000)); //in sec + sb.append("/"); + sb.append(String.format("%.3f", ((double)_ctimeProbe.longValue())/1000000000)); //in sec return sb.toString(); } diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java index 311dae2a86..943f497937 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java @@ -25,6 +25,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Stack; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.util.UtilFunctions; @@ -269,8 +270,8 @@ public class LineageItem { Stack<LineageItem> s2 = new Stack<>(); s1.push(this); s2.push(that); - //boolean ret = false; boolean ret = true; + long t0 = DMLScript.STATISTICS ? 
System.nanoTime() : 0; while (!s1.empty() && !s2.empty()) { LineageItem li1 = s1.pop(); LineageItem li2 = s2.pop(); @@ -356,6 +357,8 @@ public class LineageItem { } li1.setVisited(); } + if (DMLScript.STATISTICS) //increment probing time + LineageCacheStatistics.incrementProbeTime(System.nanoTime() - t0); return ret; } diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java index 9d931f3ce3..6978507179 100644 --- a/src/main/java/org/apache/sysds/utils/Statistics.java +++ b/src/main/java/org/apache/sysds/utils/Statistics.java @@ -649,7 +649,7 @@ public class Statistics } sb.append("LinCache writes (Mem/FS/Del): \t" + LineageCacheStatistics.displayWtrites() + ".\n"); sb.append("LinCache FStimes (Rd/Wr): \t" + LineageCacheStatistics.displayFSTime() + " sec.\n"); - sb.append("LinCache Computetime (S/M): \t" + LineageCacheStatistics.displayComputeTime() + " sec.\n"); + sb.append("LinCache Computetime (S/M/P): \t" + LineageCacheStatistics.displayComputeTime() + " sec.\n"); sb.append("LinCache Rewrites: \t\t" + LineageCacheStatistics.displayRewrites() + ".\n"); }
