This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new f22e999 [SYSTEMDS-331,332] Fix robustness lineage cache (deadlocks,
correctness)
f22e999 is described below
commit f22e9991e2370dc30a1fed01c3142c27071da42c
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Apr 10 16:28:39 2020 +0200
[SYSTEMDS-331,332] Fix robustness lineage cache (deadlocks, correctness)
This patch fixes robustness issues of lineage-based caching, especially in
multi-threaded parfor programs. These fixes include:
1) Deadlock prevention: With multi-level caching, the placeholders that
prevent concurrent computation of redundant intermediates led to
deadlocks. Subsequent threads blocked on a placeholder inside the
critical region of the central cache, so the thread that was producing
the intermediate (via a complex DAG of operations) was itself blocked
whenever it tried to access the cache.
2) Deadlocks due to wrong data types: With the introduction of scalar
caching, each thread had to decide whether to pull a scalar or a matrix
from a placeholder. Since this decision was based on the data item
(which might not be available yet in parfor), threads blocked on the
wrong type and thus again produced deadlocks. Placeholders now record
the expected output data type when they are created.
3) Correctness: The parfor loop iteration variable was not yet integrated
with lineage tracing, which led to incorrect reuse across parfor
iterations whose results depended on the iteration variable.
Furthermore, this patch narrows the unnecessarily wide public API of the
lineage cache to facilitate a correct internal implementation. However,
a number of issues remain, e.g., with the computation of compensation
plans and the probing logic.
---
docs/Tasks.txt | 2 +-
.../org/apache/sysds/parser/StatementBlock.java | 43 ++-
.../runtime/controlprogram/BasicProgramBlock.java | 16 +-
.../runtime/controlprogram/parfor/ParWorker.java | 42 ++-
.../instructions/cp/FunctionCallCPInstruction.java | 8 +-
.../apache/sysds/runtime/lineage/LineageCache.java | 388 +++++++++++----------
.../sysds/runtime/lineage/LineageCacheConfig.java | 30 ++
.../sysds/runtime/lineage/LineageRewriteReuse.java | 54 +--
.../functions/lineage/FunctionFullReuseTest.java | 42 ++-
.../functions/lineage/FunctionFullReuse6.dml | 37 ++
.../functions/lineage/FunctionFullReuse7.dml | 37 ++
11 files changed, 412 insertions(+), 287 deletions(-)
diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 42741da..d19672f 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -239,7 +239,7 @@ SYSTEMDS-320 Merge SystemDS into Apache SystemML
OK
SYSTEMDS-330 Lineage Tracing, Reuse and Integration
* 331 Cache and reuse scalar outputs (instruction and multi-level) OK
- * 332 Parfor integration with multi-level reuse
+ * 332 Parfor integration with multi-level reuse OK
* 333 Use exact execution time for cost based eviction
SYSTEMDS-340 Compiler Assisted Lineage Caching and Reuse
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index 2e87909..5991315 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -43,12 +43,12 @@ import org.apache.sysds.utils.MLContextProxy;
public class StatementBlock extends LiveVariableAnalysis implements ParseInfo
{
-
protected static final Log LOG =
LogFactory.getLog(StatementBlock.class.getName());
protected static IDSequence _seq = new IDSequence();
private static IDSequence _seqSBID = new IDSequence();
protected final long _ID;
-
+ protected final String _name;
+
protected DMLProgram _dmlProg;
protected ArrayList<Statement> _statements;
ArrayList<Hop> _hops = null;
@@ -62,6 +62,7 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
public StatementBlock() {
_ID = getNextSBID();
+ _name = "SB"+_ID;
_dmlProg = null;
_statements = new ArrayList<>();
_read = new VariableSet();
@@ -96,6 +97,10 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
public long getSBID() {
return _ID;
}
+
+ public String getName() {
+ return _name;
+ }
public void addStatement(Statement s) {
_statements.add(s);
@@ -399,8 +404,9 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
return inputs;
}
- public ArrayList<String> getOutputsofSB() {
- ArrayList<String> outputs = _liveOut != null && _updated !=
null ? new ArrayList<>() : null;
+ public ArrayList<String> getOutputNamesofSB() {
+ ArrayList<String> outputs = _liveOut != null
+ && _updated != null ? new ArrayList<>() : null;
if (_liveOut != null && _updated != null) {
for (String varName : _updated.getVariables().keySet())
{
if (_liveOut.containsVariable(varName))
@@ -409,6 +415,18 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
}
return outputs;
}
+
+ public ArrayList<DataIdentifier> getOutputsofSB() {
+ ArrayList<DataIdentifier> outputs = _liveOut != null
+ && _updated != null ? new ArrayList<>() : null;
+ if (_liveOut != null && _updated != null) {
+ for (String varName : _updated.getVariables().keySet())
{
+ if (_liveOut.containsVariable(varName))
+
outputs.add(_liveOut.getVariable(varName));
+ }
+ }
+ return outputs;
+ }
public static ArrayList<StatementBlock>
mergeStatementBlocks(ArrayList<StatementBlock> sb){
if (sb == null || sb.isEmpty())
@@ -683,29 +701,20 @@ public class StatementBlock extends LiveVariableAnalysis
implements ParseInfo
AssignmentStatement as =
(AssignmentStatement) current;
if ((as.getTargetList().size() == 1) &&
(as.getTargetList().get(0) != null)) {
raiseValidateError("Function '"
+ fcall.getName()
- + "' does not
return a value but is assigned to " + as.getTargetList().get(0),
- true);
+ + "' does not return a
value but is assigned to " + as.getTargetList().get(0), true);
}
}
- } else if (current instanceof MultiAssignmentStatement)
{
+ }
+ else if (current instanceof MultiAssignmentStatement) {
if (fstmt.getOutputParams().size() == 0) {
MultiAssignmentStatement mas =
(MultiAssignmentStatement) current;
raiseValidateError("Function '" +
fcall.getName()
- + "' does not return a
value but is assigned to " + mas.getTargetList(), true);
+ + "' does not return a value
but is assigned to " + mas.getTargetList(), true);
}
}
// handle returns by appending name mappings, but with
special handling of
// statements that contain function calls or
multi-return builtin expressions (but disabled)
-// Statement lastAdd =
newStatements.get(newStatements.size()-1);
-// if( isOutputBindingViaFunctionCall(lastAdd, prefix,
fstmt) && lastAdd instanceof AssignmentStatement )
-//
((AssignmentStatement)lastAdd).setTarget(((AssignmentStatement)current).getTarget());
-// else if ( isOutputBindingViaFunctionCall(lastAdd,
prefix, fstmt) && lastAdd instanceof MultiAssignmentStatement )
-// if( current instanceof MultiAssignmentStatement
)
-//
((MultiAssignmentStatement)lastAdd).setTargetList(((MultiAssignmentStatement)current).getTargetList());
-// else //correct for multi-assignment to
assignment transform
-//
newStatements.set(newStatements.size()-1,
createNewPartialMultiAssignment(lastAdd, current, prefix, fstmt));
-// else
appendOutputAssignments(current, prefix, fstmt,
newStatements);
}
return newStatements;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
index 1f52a75..5f44ac3 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/BasicProgramBlock.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.controlprogram;
import java.util.ArrayList;
+import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
@@ -29,7 +30,6 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
-import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
@@ -108,12 +108,10 @@ public class BasicProgramBlock extends ProgramBlock
//statement-block-level, lineage-based reuse
LineageItem[] liInputs = null;
- if (_sb != null
- && !ReuseCacheType.isNone()
- &&
LineageCacheConfig.getCacheType().isMultilevelReuse()) {
- String name = "SB" + _sb.getSBID();
+ if (_sb != null && LineageCacheConfig.isMultiLevelReuse()) {
liInputs =
LineageItemUtils.getLineageItemInputstoSB(_sb.getInputstoSB(), ec);
- if( LineageCache.reuse(_sb.getOutputsofSB(),
_sb.getOutputsofSB().size(), liInputs, name, ec) ) {
+ List<String> outNames = _sb.getOutputNamesofSB();
+ if( LineageCache.reuse(outNames, _sb.getOutputsofSB(),
outNames.size(), liInputs, _sb.getName(), ec) ) {
if( DMLScript.STATISTICS )
LineageCacheStatistics.incrementSBHits();
return;
@@ -124,9 +122,7 @@ public class BasicProgramBlock extends ProgramBlock
executeInstructions(tmp, ec);
//statement-block-level, lineage-based caching
- if (_sb != null && liInputs != null) {
- String name = "SB" + _sb.getSBID();
- LineageCache.putValue(_sb.getOutputsofSB(),
_sb.getOutputsofSB().size(), liInputs, name, ec);
- }
+ if (_sb != null && liInputs != null)
+ LineageCache.putValue(_sb.getOutputsofSB(), liInputs,
_sb.getName(), ec);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ParWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ParWorker.java
index 9f8fbb4..7b74ace 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ParWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ParWorker.java
@@ -24,6 +24,7 @@ import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.ParForStatementBlock.ResultVar;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
@@ -32,8 +33,10 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Stat;
import org.apache.sysds.runtime.controlprogram.parfor.stat.StatisticMonitor;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.IntObject;
+import org.apache.sysds.runtime.lineage.Lineage;
/**
* Super class for master/worker pattern implementations. Central place to
@@ -113,22 +116,20 @@ public abstract class ParWorker
protected void executeTask( Task task ) {
LOG.trace("EXECUTE PARFOR_WORKER ID="+_workerID+" for task
"+task.toCompactString());
- switch( task.getType() )
- {
+ switch( task.getType() ) {
case SET:
executeSetTask( task );
break;
case RANGE:
executeRangeTask( task );
- break;
+ break;
}
}
private void executeSetTask( Task task ) {
//monitoring start
Timing time1 = null, time2 = null;
- if( _monitor )
- {
+ if( _monitor ) {
time1 = new Timing(true);
time2 = new Timing(true);
}
@@ -143,6 +144,10 @@ public abstract class ParWorker
//set index values
_ec.setVariable(lVarName, indexVal);
+ if (DMLScript.LINEAGE) {
+ Lineage li = _ec.getLineage();
+ li.set(lVarName, li.getOrCreate(new
CPOperand(indexVal)));
+ }
// for each program block
for (ProgramBlock pb : _childBlocks)
@@ -157,8 +162,7 @@ public abstract class ParWorker
_numTasks++;
//monitoring end
- if( _monitor )
- {
+ if( _monitor ) {
StatisticMonitor.putPWStat(_workerID,
Stat.PARWRK_TASKSIZE, task.size());
StatisticMonitor.putPWStat(_workerID,
Stat.PARWRK_TASK_T, time2.stop());
}
@@ -167,10 +171,9 @@ public abstract class ParWorker
private void executeRangeTask( Task task ) {
//monitoring start
Timing time1 = null, time2 = null;
- if( _monitor )
- {
- time1 = new Timing(true);
- time2 = new Timing(true);
+ if( _monitor ) {
+ time1 = new Timing(true);
+ time2 = new Timing(true);
}
//core execution
@@ -183,28 +186,29 @@ public abstract class ParWorker
for( long i=lFrom; i<=lTo; i+=lIncr )
{
//set index values
- _ec.setVariable(lVarName, new IntObject(i));
+ IntObject indexVal = new IntObject(i);
+ _ec.setVariable(lVarName, indexVal);
+ if (DMLScript.LINEAGE) {
+ Lineage li = _ec.getLineage();
+ li.set(lVarName, li.getOrCreate(new
CPOperand(indexVal)));
+ }
// for each program block
for (ProgramBlock pb : _childBlocks)
pb.execute(_ec);
-
+
_numIters++;
if( _monitor )
- StatisticMonitor.putPWStat(_workerID,
Stat.PARWRK_ITER_T, time1.stop());
+ StatisticMonitor.putPWStat(_workerID,
Stat.PARWRK_ITER_T, time1.stop());
}
_numTasks++;
//monitoring end
- if( _monitor )
- {
+ if( _monitor ) {
StatisticMonitor.putPWStat(_workerID,
Stat.PARWRK_TASKSIZE, task.size());
StatisticMonitor.putPWStat(_workerID,
Stat.PARWRK_TASK_T, time2.stop());
}
}
-
}
-
-
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 5a24ad8..e605a55 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -39,6 +39,7 @@ import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.Lineage;
import org.apache.sysds.runtime.lineage.LineageCache;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
@@ -115,7 +116,7 @@ public class FunctionCallCPInstruction extends
CPInstruction {
}
// check if function outputs can be reused from cache
- LineageItem[] liInputs = DMLScript.LINEAGE ?
+ LineageItem[] liInputs = DMLScript.LINEAGE &&
LineageCacheConfig.isMultiLevelReuse() ?
LineageItemUtils.getLineage(ec, _boundInputs) : null;
if( reuseFunctionOutputs(liInputs, fpb, ec) )
return; //only if all the outputs are found in cache
@@ -224,7 +225,8 @@ public class FunctionCallCPInstruction extends
CPInstruction {
}
//update lineage cache with the functions outputs
- LineageCache.putValue(_boundOutputNames, numOutputs, liInputs,
_functionName, ec);
+ if( DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse()
)
+ LineageCache.putValue(fpb.getOutputParams(), liInputs,
_functionName, ec);
}
@Override
@@ -261,7 +263,7 @@ public class FunctionCallCPInstruction extends
CPInstruction {
private boolean reuseFunctionOutputs(LineageItem[] liInputs,
FunctionProgramBlock fpb, ExecutionContext ec) {
int numOutputs = Math.min(_boundOutputNames.size(),
fpb.getOutputParams().size());
- boolean reuse = LineageCache.reuse(_boundOutputNames,
numOutputs, liInputs, _functionName, ec);
+ boolean reuse = LineageCache.reuse(_boundOutputNames,
fpb.getOutputParams(), numOutputs, liInputs, _functionName, ec);
if (reuse && DMLScript.STATISTICS) {
//decrement the call count for this function
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 789b9f7..2741b70 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -20,10 +20,12 @@
package org.apache.sysds.runtime.lineage;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.cost.CostEstimatorStaticRuntime;
import org.apache.sysds.lops.MMTSJ.MMTSJType;
+import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -51,7 +53,8 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
-public class LineageCache {
+public class LineageCache
+{
private static final Map<LineageItem, Entry> _cache = new HashMap<>();
private static final Map<LineageItem, SpilledItem> _spillList = new
HashMap<>();
private static final HashSet<LineageItem> _removelist = new HashSet<>();
@@ -67,7 +70,18 @@ public class LineageCache {
CACHE_LIMIT = (long)(CACHE_FRAC * maxMem);
}
- //--------------------- CACHE LOGIC METHODS ----------------------
+ // Cache Synchronization Approach:
+ // The central static cache is only synchronized in a fine-grained
manner
+ // for short get, put, or remove calls or during eviction. All
blocking of
+ // threads for computing the values of placeholders is done on the
individual
+ // entry objects which reduces contention and prevents deadlocks in
case of
+ // function/statement block placeholders which computation itself
might be
+ // a complex workflow of operations that accesses the cache as well.
+
+
+ ///////////////////////////////////////
+ // Public Cache API (keep it narrow) //
+ ///////////////////////////////////////
public static boolean reuse(Instruction inst, ExecutionContext ec) {
if (ReuseCacheType.isNone())
@@ -76,77 +90,85 @@ public class LineageCache {
boolean reuse = false;
//NOTE: the check for computation CP instructions ensures that
the output
// will always fit in memory and hence can be pinned
unconditionally
- if (inst instanceof ComputationCPInstruction &&
LineageCache.isReusable(inst, ec)) {
- LineageItem item = ((ComputationCPInstruction)
inst).getLineageItems(ec)[0];
+ if (LineageCacheConfig.isReusable(inst, ec)) {
+ ComputationCPInstruction cinst =
(ComputationCPInstruction) inst;
+ LineageItem item = cinst.getLineageItems(ec)[0];
+ //atomic try reuse full/partial and set placeholder,
without
+ //obtaining value to avoid blocking in critical section
+ Entry e = null;
synchronized( _cache ) {
//try to reuse full or partial intermediates
if
(LineageCacheConfig.getCacheType().isFullReuse())
- reuse = fullReuse(item,
(ComputationCPInstruction)inst, ec);
- if
(LineageCacheConfig.getCacheType().isPartialReuse())
- reuse |=
LineageRewriteReuse.executeRewrites(inst, ec);
-
- if (reuse && DMLScript.STATISTICS)
-
LineageCacheStatistics.incrementInstHits();
+ e = LineageCache.probe(item) ?
getIntern(item) : null;
+ //TODO need to also move execution of
compensation plan out of here
+ //(create lazily evaluated entry)
+ if (e == null &&
LineageCacheConfig.getCacheType().isPartialReuse())
+ if(
LineageRewriteReuse.executeRewrites(inst, ec) )
+ e = getIntern(item);
+ reuse = (e != null);
//create a placeholder if no reuse to avoid
redundancy
//(e.g., concurrent threads that try to start
the computation)
- if(!reuse && isMarkedForCaching(inst, ec))
- putIntern(item, null, null, 0);
+ if(!reuse && isMarkedForCaching(inst, ec)) {
+ putIntern(item,
cinst.output.getDataType(), null, null, 0);
+ }
+ }
+
+ if( reuse ) { //reuse
+ //put reuse value into symbol table (w/
blocking on placeholders)
+ if (e.isMatrixValue())
+
ec.setMatrixOutput(cinst.output.getName(), e.getMBValue());
+ else
+
ec.setScalarOutput(cinst.output.getName(), e.getSOValue());
+ if (DMLScript.STATISTICS)
+
LineageCacheStatistics.incrementInstHits();
+ reuse = true;
}
}
return reuse;
}
- public static Entry reuse(LineageItem item) {
- if (ReuseCacheType.isNone())
- return null;
-
- Entry e = null;
- synchronized( _cache ) {
- if (LineageCache.probe(item))
- e = LineageCache.get(item);
- else
- //create a placeholder if no reuse to avoid
redundancy
- //(e.g., concurrent threads that try to start
the computation)
- putIntern(item, null, null, 0);
- //FIXME: parfor - every thread gets different
function names
- }
- return e;
- }
-
- public static boolean reuse(List<String> outputs, int numOutputs,
LineageItem[] liInputs, String name, ExecutionContext ec)
+ public static boolean reuse(List<String> outNames, List<DataIdentifier>
outParams, int numOutputs, LineageItem[] liInputs, String name,
ExecutionContext ec)
{
- if( ReuseCacheType.isNone() ||
!LineageCacheConfig.getCacheType().isMultilevelReuse())
+ if( !LineageCacheConfig.isMultiLevelReuse())
return false;
-
- boolean reuse = (numOutputs != 0);
+
+ boolean reuse = (outParams.size() != 0);
HashMap<String, Data> funcOutputs = new HashMap<>();
HashMap<String, LineageItem> funcLIs = new HashMap<>();
for (int i=0; i<numOutputs; i++) {
String opcode = name + String.valueOf(i+1);
- LineageItem li = new LineageItem(outputs.get(i),
opcode, liInputs);
- Entry cachedValue = LineageCache.reuse(li);
+ LineageItem li = new LineageItem(outNames.get(i),
opcode, liInputs);
+ Entry e = null;
+ synchronized( _cache ) {
+ if (LineageCache.probe(li))
+ e = LineageCache.getIntern(li);
+ else
+ //create a placeholder if no reuse to
avoid redundancy
+ //(e.g., concurrent threads that try to
start the computation)
+ putIntern(li,
outParams.get(i).getDataType(), null, null, 0);
+ //FIXME: parfor - every thread gets
different function names
+ }
//TODO: handling of recursive calls
- if (cachedValue != null && !cachedValue.isNullVal()) {
- String boundVarName = outputs.get(i);
+ if (e != null && !e.isNullVal()) {
+ String boundVarName = outNames.get(i);
Data boundValue = null;
//convert to matrix object
- if (cachedValue.isMatrixValue()) {
- MetaDataFormat md = new
MetaDataFormat(cachedValue.getMBValue().getDataCharacteristics(),
-
OutputInfo.BinaryCellOutputInfo, InputInfo.BinaryCellInputInfo);
+ if (e.isMatrixValue()) {
+ MetaDataFormat md = new
MetaDataFormat(e.getMBValue().getDataCharacteristics(),
+
OutputInfo.BinaryCellOutputInfo, InputInfo.BinaryCellInputInfo);
boundValue = new
MatrixObject(ValueType.FP64, boundVarName, md);
-
((MatrixObject)boundValue).acquireModify(cachedValue.getMBValue());
+
((MatrixObject)boundValue).acquireModify(e.getMBValue());
((MatrixObject)boundValue).release();
}
else
- boundValue = cachedValue.getSOValue();
+ boundValue = e.getSOValue();
funcOutputs.put(boundVarName, boundValue);
-
- LineageItem orig = _cache.get(li)._origItem;
//FIXME: synchronize
+ LineageItem orig = e._origItem;
funcLIs.put(boundVarName, orig);
}
else {
@@ -169,17 +191,37 @@ public class LineageCache {
//map original lineage items return to the calling site
funcLIs.forEach((var, li) -> ec.getLineage().set(var,
li));
}
+
return reuse;
}
+ public static boolean probe(LineageItem key) {
+ //TODO problematic as after probe the matrix might be kicked
out of cache
+ boolean p = (_cache.containsKey(key) ||
_spillList.containsKey(key));
+ if (!p && DMLScript.STATISTICS && _removelist.contains(key))
+ // The sought entry was in cache but removed later
+ LineageCacheStatistics.incrementDelHits();
+ return p;
+ }
+
+ public static MatrixBlock getMatrix(LineageItem key) {
+ Entry e = null;
+ synchronized( _cache ) {
+ e = getIntern(key);
+ }
+ return e.getMBValue();
+ }
+
//NOTE: safe to pin the object in memory as coming from CPInstruction
- public static void put(Instruction inst, ExecutionContext ec) {
- if (inst instanceof ComputationCPInstruction &&
isReusable(inst, ec) ) {
+ //TODO why do we need both of these public put methods
+ public static void putMatrix(Instruction inst, ExecutionContext ec) {
+ if (LineageCacheConfig.isReusable(inst, ec) ) {
LineageItem item = ((LineageTraceable)
inst).getLineageItems(ec)[0];
//This method is called only to put matrix value
MatrixObject mo =
ec.getMatrixObject(((ComputationCPInstruction) inst).output);
synchronized( _cache ) {
- putIntern(item, mo.acquireReadAndRelease(),
null, getRecomputeEstimate(inst, ec));
+ putIntern(item, DataType.MATRIX,
mo.acquireReadAndRelease(),
+ null, getRecomputeEstimate(inst, ec));
}
}
}
@@ -187,18 +229,18 @@ public class LineageCache {
public static void putValue(Instruction inst, ExecutionContext ec) {
if (ReuseCacheType.isNone())
return;
- if (inst instanceof ComputationCPInstruction &&
isReusable(inst, ec) ) {
- if (!isMarkedForCaching(inst, ec)) return;
+ if (LineageCacheConfig.isReusable(inst, ec) ) {
+ //if (!isMarkedForCaching(inst, ec)) return;
LineageItem item = ((LineageTraceable)
inst).getLineageItems(ec)[0];
- //MatrixObject mo =
ec.getMatrixObject(((ComputationCPInstruction) inst).output);
Data data = ec.getVariable(((ComputationCPInstruction)
inst).output);
- MatrixObject mo = data instanceof MatrixObject ?
(MatrixObject)data : null;
- ScalarObject so = data instanceof ScalarObject ?
(ScalarObject)data : null;
- MatrixBlock Mval = mo != null ?
mo.acquireReadAndRelease() : null;
- _cache.get(item).setValue(Mval, so,
getRecomputeEstimate(inst, ec)); //outside sync to prevent deadlocks
- long size = _cache.get(item).getSize();
-
+ double cest = getRecomputeEstimate(inst, ec);
synchronized( _cache ) {
+ if( data instanceof MatrixObject )
+
_cache.get(item).setValue(((MatrixObject)data).acquireReadAndRelease(), cest);
+ else
+
_cache.get(item).setValue((ScalarObject)data, cest);
+ long size = _cache.get(item).getSize();
+
if( !isBelowThreshold(size) )
makeSpace(size);
updateSize(size, true);
@@ -206,42 +248,17 @@ public class LineageCache {
}
}
- public static void putValue(LineageItem item, LineageItem probeItem) {
- if (ReuseCacheType.isNone())
- return;
- if (LineageCache.probe(probeItem)) {
- Entry oe = LineageCache.get(probeItem);
- Entry e = _cache.get(item);
- //TODO: compute estimate for function
- if (oe.isMatrixValue())
- e.setValue(oe.getMBValue(), null, 0);
- else
- e.setValue(null, oe.getSOValue(), 0);
- e._origItem = probeItem;
-
- long size = oe.getSize();
- synchronized( _cache ) {
- if(!isBelowThreshold(size))
- makeSpace(size);
- updateSize(size, true);
- }
- }
- else
- removeEntry(item); //remove the placeholder
-
- }
-
- public static void putValue(List<String> outputs, int numOutputs,
LineageItem[] liInputs, String name, ExecutionContext ec)
+ public static void putValue(List<DataIdentifier> outputs, LineageItem[]
liInputs, String name, ExecutionContext ec)
{
- if( ReuseCacheType.isNone() ||
!LineageCacheConfig.getCacheType().isMultilevelReuse())
+ if( !LineageCacheConfig.isMultiLevelReuse())
return;
HashMap<LineageItem, LineageItem> FuncLIMap = new HashMap<>();
boolean AllOutputsCacheable = true;
- for (int i=0; i<numOutputs; i++) {
+ for (int i=0; i<outputs.size(); i++) {
String opcode = name + String.valueOf(i+1);
- LineageItem li = new LineageItem(outputs.get(i),
opcode, liInputs);
- String boundVarName = outputs.get(i);
+ LineageItem li = new
LineageItem(outputs.get(i).getName(), opcode, liInputs);
+ String boundVarName = outputs.get(i).getName();
LineageItem boundLI = ec.getLineage().get(boundVarName);
if (boundLI != null)
boundLI.resetVisitStatus();
@@ -254,23 +271,42 @@ public class LineageCache {
}
//cache either all the outputs, or none.
- if(AllOutputsCacheable)
- FuncLIMap.forEach((Li, boundLI) ->
LineageCache.putValue(Li, boundLI));
- else
- //remove all the placeholders
- FuncLIMap.forEach((Li, boundLI) ->
LineageCache.removeEntry(Li));
+ synchronized( _cache ) {
+ //move or remove placeholders
+ if(AllOutputsCacheable)
+ FuncLIMap.forEach((Li, boundLI) -> mvIntern(Li,
boundLI));
+ else
+ FuncLIMap.forEach((Li, boundLI) ->
removeEntry(Li));
+ }
return;
}
- private static void putIntern(LineageItem key, MatrixBlock Mval,
ScalarObject Sval, double compcost) {
+ public static void resetCache() {
+ synchronized( _cache ) {
+ _cache.clear();
+ _spillList.clear();
+ _head = null;
+ _end = null;
+ // reset cache size, otherwise the cache clear leads to
unusable
+ // space which means evictions could run into endless
loops
+ _cachesize = 0;
+ if (DMLScript.STATISTICS)
+ _removelist.clear();
+ }
+ }
+
+ /////////////////////////////////////////
+ // Internal Cache Logic Implementation //
+ /////////////////////////////////////////
+
+ private static void putIntern(LineageItem key, DataType dt, MatrixBlock
Mval, ScalarObject Sval, double compcost) {
if (_cache.containsKey(key))
//can come here if reuse_partial option is enabled
- return;
- //throw new DMLRuntimeException("Redundant lineage
caching detected: "+inst);
+ return;
// Create a new entry.
- Entry newItem = new Entry(key, Mval, Sval, compcost);
+ Entry newItem = new Entry(key, dt, Mval, Sval, compcost);
// Make space by removing or spilling LRU entries.
if( Mval != null || Sval != null ) {
@@ -290,40 +326,7 @@ public class LineageCache {
LineageCacheStatistics.incrementMemWrites();
}
- protected static boolean probe(LineageItem key) {
- boolean p = (_cache.containsKey(key) ||
_spillList.containsKey(key));
- if (!p && DMLScript.STATISTICS && _removelist.contains(key))
- // The sought entry was in cache but removed later
- LineageCacheStatistics.incrementDelHits();
- return p;
- }
-
- public static void resetCache() {
- _cache.clear();
- _spillList.clear();
- _head = null;
- _end = null;
- // reset cache size, otherwise the cache clear leads to
unusable
- // space which means evictions could run into endless loops
- _cachesize = 0;
- if (DMLScript.STATISTICS)
- _removelist.clear();
- }
-
-
- private static boolean fullReuse (LineageItem item,
ComputationCPInstruction inst, ExecutionContext ec) {
- if (LineageCache.probe(item)) {
- Entry e = LineageCache.get(item);
- if (e.isMatrixValue())
- ec.setMatrixOutput(inst.output.getName(),
e.getMBValue());
- else
- ec.setScalarOutput(inst.output.getName(),
e.getSOValue());
- return true;
- }
- return false;
- }
-
- protected static Entry get(LineageItem key) {
+ private static Entry getIntern(LineageItem key) {
// This method is called only when entry is present either in
cache or in local FS.
if (_cache.containsKey(key)) {
// Read and put the entry at head.
@@ -337,44 +340,39 @@ public class LineageCache {
else
return readFromLocalFS(key);
}
+
- public static boolean isReusable (Instruction inst, ExecutionContext
ec) {
- // TODO: Move this to the new class LineageCacheConfig and
extend
- return inst.getOpcode().equalsIgnoreCase("tsmm")
- || inst.getOpcode().equalsIgnoreCase("ba+*")
- || inst.getOpcode().equalsIgnoreCase("*")
- || inst.getOpcode().equalsIgnoreCase("/")
- || inst.getOpcode().equalsIgnoreCase("+")
- || inst.getOpcode().equalsIgnoreCase("nrow")
- || inst.getOpcode().equalsIgnoreCase("ncol")
- ||
inst.getOpcode().equalsIgnoreCase("rightIndex")
- ||
inst.getOpcode().equalsIgnoreCase("leftIndex")
- ||
inst.getOpcode().equalsIgnoreCase("groupedagg")
- || inst.getOpcode().equalsIgnoreCase("r'")
- || (inst.getOpcode().equalsIgnoreCase("append")
&& isVectorAppend(inst, ec))
- || inst.getOpcode().equalsIgnoreCase("solve")
- || inst.getOpcode().contains("spoof");
- }
-
- private static boolean isVectorAppend(Instruction inst,
ExecutionContext ec) {
- ComputationCPInstruction cpinst = (ComputationCPInstruction)
inst;
- if( !cpinst.input1.isMatrix() || !cpinst.input2.isMatrix() )
- return false;
- long c1 = ec.getMatrixObject(cpinst.input1).getNumColumns();
- long c2 = ec.getMatrixObject(cpinst.input2).getNumColumns();
- return(c1 == 1 || c2 == 1);
+ private static void mvIntern(LineageItem item, LineageItem probeItem) {
+ if (ReuseCacheType.isNone())
+ return;
+ if (LineageCache.probe(probeItem)) {
+ Entry oe = getIntern(probeItem);
+ Entry e = _cache.get(item);
+ //TODO: compute estimate for function
+ if (oe.isMatrixValue())
+ e.setValue(oe.getMBValue(), 0);
+ else
+ e.setValue(oe.getSOValue(), 0);
+ e._origItem = probeItem;
+
+ long size = oe.getSize();
+ if(!isBelowThreshold(size))
+ makeSpace(size);
+ updateSize(size, true);
+ }
+ else
+ removeEntry(item); //remove the placeholder
}
- public static boolean isMarkedForCaching (Instruction inst,
ExecutionContext ec) {
+ private static boolean isMarkedForCaching (Instruction inst,
ExecutionContext ec) {
if (!LineageCacheConfig.getCompAssRW())
return true;
if (((ComputationCPInstruction)inst).output.isMatrix()) {
MatrixObject mo =
ec.getMatrixObject(((ComputationCPInstruction)inst).output);
//limit this to full reuse as partial reuse is
applicable even for loop dependent operation
- boolean marked = (LineageCacheConfig.getCacheType() ==
ReuseCacheType.REUSE_FULL
- && !mo.isMarked()) ? false : true;
- return marked;
+ return !(LineageCacheConfig.getCacheType() ==
ReuseCacheType.REUSE_FULL
+ && !mo.isMarked());
}
else
return true;
@@ -397,7 +395,6 @@ public class LineageCache {
continue;
}
- double reduction = _cache.get(_end._key).getSize();
if (_cache.get(_end._key).isMatrixValue()) { //spill
matrix blocks only
if (_cache.get(_end._key)._compEst >
getDiskSpillEstimate()
&&
LineageCacheConfig.isSetSpill())
@@ -410,8 +407,8 @@ public class LineageCache {
setEnd2Head(_end);
continue;
}
- removeEntry(reduction);
- }
+ removeLastEntry();
+ }
}
private static void updateSize(long space, boolean addspace) {
@@ -617,7 +614,7 @@ public class LineageCache {
}
// Restore to cache
LocalFileUtils.deleteFileIfExists(_spillList.get(key)._outfile,
true);
- putIntern(key, mb, null, _spillList.get(key)._compEst);
+ putIntern(key, DataType.MATRIX, mb, null,
_spillList.get(key)._compEst);
_spillList.remove(key);
if (DMLScript.STATISTICS) {
long t1 = System.nanoTime();
@@ -627,7 +624,30 @@ public class LineageCache {
return _cache.get(key);
}
- //------------------ LINKEDLIST MAINTENANCE METHODS -------------------
+ ////////////////////////////////////////////
+ // Cache Maintenance and Lookup Functions //
+ ////////////////////////////////////////////
+
+ //Removes the entry at the end of the linked list (_end) from the map
+ //and the list, subtracts its size from _cachesize, and records its key
+ //in _removelist if statistics are enabled.
+ private static void removeLastEntry() {
+ if (DMLScript.STATISTICS)
+ _removelist.add(_end._key);
+ Entry e = _cache.remove(_end._key);
+ _cachesize -= e.getSize();
+ delete(_end);
+ }
+
+ //Removes the entry for key (if present) from the linked list and map.
+ //NOTE(review): unlike removeLastEntry, this does not adjust _cachesize;
+ //it appears to be intended for value-less placeholders (see mvIntern)
+ //-- confirm before using it for entries with accounted sizes.
+ private static void removeEntry(LineageItem key) {
+ // Remove the entry for key
+ if (!_cache.containsKey(key))
+ return;
+ delete(_cache.get(key));
+ _cache.remove(key);
+ }
+
+ //Moves the given entry to the head of the linked list.
+ private static void setEnd2Head(Entry entry) {
+ delete(entry);
+ setHead(entry);
+ }
private static void delete(Entry entry) {
if (entry._prev != null)
@@ -650,29 +670,13 @@ public class LineageCache {
_end = _head;
}
- private static void setEnd2Head(Entry entry) {
- delete(entry);
- setHead(entry);
- }
-
- private static void removeEntry(double space) {
- if (DMLScript.STATISTICS)
- _removelist.add(_end._key);
- _cache.remove(_end._key);
- _cachesize -= space;
- delete(_end);
- }
-
- public static void removeEntry(LineageItem key) {
- // Remove the entry for key
- if (!_cache.containsKey(key))
- return;
- delete(_cache.get(key));
- _cache.remove(key);
- }
+ ////////////////////////////////////
+ // Internal Cache Data Structures //
+ ////////////////////////////////////
- static class Entry {
+ private static class Entry {
private final LineageItem _key;
+ private final DataType _dt;
private MatrixBlock _MBval;
private ScalarObject _SOval;
double _compEst;
@@ -680,8 +684,9 @@ public class LineageCache {
private Entry _next;
private LineageItem _origItem;
- public Entry(LineageItem key, MatrixBlock Mval, ScalarObject
Sval, double computecost) {
+ public Entry(LineageItem key, DataType dt, MatrixBlock Mval,
ScalarObject Sval, double computecost) {
_key = key;
+ _dt = dt;
_MBval = Mval;
_SOval = Sval;
_compEst = computecost;
@@ -725,19 +730,20 @@ public class LineageCache {
}
public boolean isMatrixValue() {
- return(_MBval != null);
+ return _dt.isMatrix();
}
public synchronized void setValue(MatrixBlock val, double
compEst) {
_MBval = val;
_compEst = compEst;
+ //resume all threads waiting for val
notifyAll();
}
- public synchronized void setValue(MatrixBlock mval,
ScalarObject so, double compEst) {
- _MBval = mval;
- _SOval = so;
+ public synchronized void setValue(ScalarObject val, double
compEst) {
+ _SOval = val;
_compEst = compEst;
+ //resume all threads waiting for val
notifyAll();
}
}
@@ -747,8 +753,8 @@ public class LineageCache {
double _compEst;
public SpilledItem(String outfile, double computecost) {
- this._outfile = outfile;
- this._compEst = computecost;
+ _outfile = outfile;
+ _compEst = computecost;
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index e4ce09b..75a305a 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -19,11 +19,21 @@
package org.apache.sysds.runtime.lineage;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
+
import java.util.ArrayList;
public class LineageCacheConfig {
+ //opcodes of instructions whose outputs are eligible for reuse
+ //(matched exactly against Instruction.getOpcode() in isReusable)
+ private static final String[] REUSE_OPCODES = new String[] {
+ "tsmm", "ba+*", "*", "/", "+", "nrow", "ncol",
+ "rightIndex", "leftIndex", "groupedagg", "r'", "solve", "spoof"
+ };
+
public enum ReuseCacheType {
REUSE_FULL,
REUSE_PARTIAL,
@@ -69,6 +79,21 @@ public class LineageCacheConfig {
setSpill(false); //disable spilling of cache entries to disk
}
+ /**
+ * Indicates if the output of the given instruction is eligible for
+ * lineage-based reuse. The instruction must be a CP computation
+ * instruction, and its opcode must meet one of these conditions:
+ * it is in REUSE_OPCODES, it contains "spoof", or it is an append
+ * with a single-column matrix input.
+ */
+ public static boolean isReusable (Instruction inst, ExecutionContext ec) {
+ if( !(inst instanceof ComputationCPInstruction) )
+ return false;
+ String opcode = inst.getOpcode();
+ return ArrayUtils.contains(REUSE_OPCODES, opcode)
+ //substring match: accept any opcode containing "spoof", not only the exact string
+ || opcode.contains("spoof")
+ || (opcode.equals("append") && isVectorAppend(inst, ec));
+ }
+
+ //true if both inputs are matrices and at least one has a single column
+ private static boolean isVectorAppend(Instruction inst, ExecutionContext ec) {
+ ComputationCPInstruction cpinst = (ComputationCPInstruction) inst;
+ if( !cpinst.input1.isMatrix() || !cpinst.input2.isMatrix() )
+ return false;
+ long c1 = ec.getMatrixObject(cpinst.input1).getNumColumns();
+ long c2 = ec.getMatrixObject(cpinst.input2).getNumColumns();
+ return(c1 == 1 || c2 == 1);
+ }
+
public static void setConfigTsmmCbind(ReuseCacheType ct) {
_cacheType = ct;
_itemH = CachedItemHead.TSMM;
@@ -110,6 +135,11 @@ public class LineageCacheConfig {
public static ReuseCacheType getCacheType() {
return _cacheType;
}
+
+ //true if lineage reuse is enabled and the configured cache type
+ //reports multi-level reuse via isMultilevelReuse()
+ public static boolean isMultiLevelReuse() {
+ return !ReuseCacheType.isNone()
+ && _cacheType.isMultilevelReuse();
+ }
public static CachedItemHead getCachedItemHead() {
return _itemH;
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
index fb5a21f..f1b8c58 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRewriteReuse.java
@@ -99,7 +99,7 @@ public class LineageRewriteReuse
ec.setVariable(((ComputationCPInstruction)curr).output.getName(),
lrwec.getVariable(LR_VAR));
//put the result into the cache
- LineageCache.put(curr, ec);
+ LineageCache.putMatrix(curr, ec);
DMLScript.EXPLAIN = et; //TODO can't change this here
//cleanup execution context
@@ -529,7 +529,7 @@ public class LineageRewriteReuse
private static boolean isTsmmCbind(Instruction curr, ExecutionContext
ec, Map<String, MatrixBlock> inCache)
{
- if (!LineageCache.isReusable(curr, ec)) {
+ if (!LineageCacheConfig.isReusable(curr, ec)) {
return false;
}
@@ -543,10 +543,10 @@ public class LineageRewriteReuse
LineageItem input1 = source.getInputs()[0];
LineageItem tmp = new LineageItem("toProbe",
curr.getOpcode(), new LineageItem[] {input1});
if (LineageCache.probe(tmp))
- inCache.put("lastMatrix",
LineageCache.get(tmp).getMBValue());
+ inCache.put("lastMatrix",
LineageCache.getMatrix(tmp));
// look for the appended column in cache
if (LineageCache.probe(source.getInputs()[1]))
- inCache.put("deltaX",
LineageCache.get(source.getInputs()[1]).getMBValue());
+ inCache.put("deltaX",
LineageCache.getMatrix(source.getInputs()[1]));
}
// return true only if the last tsmm is found
return inCache.containsKey("lastMatrix") ? true : false;
@@ -554,7 +554,7 @@ public class LineageRewriteReuse
private static boolean isTsmmRbind(Instruction curr, ExecutionContext
ec, Map<String, MatrixBlock> inCache)
{
- if (!LineageCache.isReusable(curr, ec))
+ if (!LineageCacheConfig.isReusable(curr, ec))
return false;
// If the input to tsmm came from rbind, look for both the
inputs in cache.
@@ -566,10 +566,10 @@ public class LineageRewriteReuse
LineageItem input1 = source.getInputs()[0];
LineageItem tmp = new LineageItem("toProbe",
curr.getOpcode(), new LineageItem[] {input1});
if (LineageCache.probe(tmp))
- inCache.put("lastMatrix",
LineageCache.get(tmp).getMBValue());
+ inCache.put("lastMatrix",
LineageCache.getMatrix(tmp));
// look for the appended column in cache
if (LineageCache.probe(source.getInputs()[1]))
- inCache.put("deltaX",
LineageCache.get(source.getInputs()[1]).getMBValue());
+ inCache.put("deltaX",
LineageCache.getMatrix(source.getInputs()[1]));
}
// return true only if the last tsmm is found
return inCache.containsKey("lastMatrix") ? true : false;
@@ -577,7 +577,7 @@ public class LineageRewriteReuse
private static boolean isTsmm2Cbind (Instruction curr, ExecutionContext
ec, Map<String, MatrixBlock> inCache)
{
- if (!LineageCache.isReusable(curr, ec))
+ if (!LineageCacheConfig.isReusable(curr, ec))
return false;
//TODO: support nary cbind
@@ -593,10 +593,10 @@ public class LineageRewriteReuse
LineageItem tmp = new
LineageItem("comb", "cbind", new LineageItem[] {L2appin1,
source.getInputs()[1]});
LineageItem toProbe = new
LineageItem("toProbe", curr.getOpcode(), new LineageItem[] {tmp});
if (LineageCache.probe(toProbe))
- inCache.put("lastMatrix",
LineageCache.get(toProbe).getMBValue());
+ inCache.put("lastMatrix",
LineageCache.getMatrix(toProbe));
// look for the appended column in cache
if
(LineageCache.probe(input.getInputs()[1]))
- inCache.put("deltaX",
LineageCache.get(input.getInputs()[1]).getMBValue());
+ inCache.put("deltaX",
LineageCache.getMatrix(input.getInputs()[1]));
}
}
// return true only if the last tsmm is found
@@ -605,7 +605,7 @@ public class LineageRewriteReuse
private static boolean isMatMulRbindLeft(Instruction curr,
ExecutionContext ec, Map<String, MatrixBlock> inCache)
{
- if (!LineageCache.isReusable(curr, ec))
+ if (!LineageCacheConfig.isReusable(curr, ec))
return false;
// If the left input to ba+* came from rbind, look for both the
inputs in cache.
@@ -618,10 +618,10 @@ public class LineageRewriteReuse
// create ba+* lineage on top of the input of
last append
LineageItem tmp = new LineageItem("toProbe",
curr.getOpcode(), new LineageItem[] {leftSource, right});
if (LineageCache.probe(tmp))
- inCache.put("lastMatrix",
LineageCache.get(tmp).getMBValue());
+ inCache.put("lastMatrix",
LineageCache.getMatrix(tmp));
// look for the appended column in cache
if (LineageCache.probe(left.getInputs()[1]))
- inCache.put("deltaX",
LineageCache.get(left.getInputs()[1]).getMBValue());
+ inCache.put("deltaX",
LineageCache.getMatrix(left.getInputs()[1]));
}
}
// return true only if the last tsmm is found
@@ -630,7 +630,7 @@ public class LineageRewriteReuse
private static boolean isMatMulCbindRight(Instruction curr,
ExecutionContext ec, Map<String, MatrixBlock> inCache)
{
- if (!LineageCache.isReusable(curr, ec))
+ if (!LineageCacheConfig.isReusable(curr, ec))
return false;
// If the right input to ba+* came from cbind, look for both
the inputs in cache.
@@ -643,10 +643,10 @@ public class LineageRewriteReuse
// create ba+* lineage on top of the input of
last append
LineageItem tmp = new LineageItem("toProbe",
curr.getOpcode(), new LineageItem[] {left, rightSource});
if (LineageCache.probe(tmp))
- inCache.put("lastMatrix",
LineageCache.get(tmp).getMBValue());
+ inCache.put("lastMatrix",
LineageCache.getMatrix(tmp));
// look for the appended column in cache
if (LineageCache.probe(right.getInputs()[1]))
- inCache.put("deltaY",
LineageCache.get(right.getInputs()[1]).getMBValue());
+ inCache.put("deltaY",
LineageCache.getMatrix(right.getInputs()[1]));
}
}
return inCache.containsKey("lastMatrix") ? true : false;
@@ -654,7 +654,7 @@ public class LineageRewriteReuse
private static boolean isElementMulRbind(Instruction curr,
ExecutionContext ec, Map<String, MatrixBlock> inCache)
{
- if (!LineageCache.isReusable(curr, ec))
+ if (!LineageCacheConfig.isReusable(curr, ec))
return false;
// If the inputs to * came from rbind, look for both the inputs
in cache.
@@ -668,12 +668,12 @@ public class LineageRewriteReuse
// create * lineage on top of the input of last
append
LineageItem tmp = new LineageItem("toProbe",
curr.getOpcode(), new LineageItem[] {leftSource, rightSource});
if (LineageCache.probe(tmp))
- inCache.put("lastMatrix",
LineageCache.get(tmp).getMBValue());
+ inCache.put("lastMatrix",
LineageCache.getMatrix(tmp));
// look for the appended rows in cache
if (LineageCache.probe(left.getInputs()[1]))
- inCache.put("deltaX",
LineageCache.get(left.getInputs()[1]).getMBValue());
+ inCache.put("deltaX",
LineageCache.getMatrix(left.getInputs()[1]));
if (LineageCache.probe(right.getInputs()[1]))
- inCache.put("deltaY",
LineageCache.get(right.getInputs()[1]).getMBValue());
+ inCache.put("deltaY",
LineageCache.getMatrix(right.getInputs()[1]));
}
}
return inCache.containsKey("lastMatrix") ? true : false;
@@ -681,7 +681,7 @@ public class LineageRewriteReuse
private static boolean isElementMulCbind(Instruction curr,
ExecutionContext ec, Map<String, MatrixBlock> inCache)
{
- if (!LineageCache.isReusable(curr, ec))
+ if (!LineageCacheConfig.isReusable(curr, ec))
return false;
// If the inputs to * came from cbind, look for both the inputs
in cache.
@@ -695,12 +695,12 @@ public class LineageRewriteReuse
// create * lineage on top of the input of last
append
LineageItem tmp = new LineageItem("toProbe",
curr.getOpcode(), new LineageItem[] {leftSource, rightSource});
if (LineageCache.probe(tmp))
- inCache.put("lastMatrix",
LineageCache.get(tmp).getMBValue());
+ inCache.put("lastMatrix",
LineageCache.getMatrix(tmp));
// look for the appended columns in cache
if (LineageCache.probe(left.getInputs()[1]))
- inCache.put("deltaX",
LineageCache.get(left.getInputs()[1]).getMBValue());
+ inCache.put("deltaX",
LineageCache.getMatrix(left.getInputs()[1]));
if (LineageCache.probe(right.getInputs()[1]))
- inCache.put("deltaY",
LineageCache.get(right.getInputs()[1]).getMBValue());
+ inCache.put("deltaY",
LineageCache.getMatrix(right.getInputs()[1]));
}
}
return inCache.containsKey("lastMatrix") ? true : false;
@@ -708,7 +708,7 @@ public class LineageRewriteReuse
private static boolean isAggCbind (Instruction curr, ExecutionContext
ec, Map<String, MatrixBlock> inCache)
{
- if (!LineageCache.isReusable(curr, ec)) {
+ if (!LineageCacheConfig.isReusable(curr, ec)) {
return false;
}
@@ -726,10 +726,10 @@ public class LineageRewriteReuse
LineageItem tmp = new LineageItem("toProbe",
curr.getOpcode(),
new LineageItem[] {input1,
groups, weights, fn, ngroups});
if (LineageCache.probe(tmp))
- inCache.put("lastMatrix",
LineageCache.get(tmp).getMBValue());
+ inCache.put("lastMatrix",
LineageCache.getMatrix(tmp));
// look for the appended column in cache
if (LineageCache.probe(target.getInputs()[1]))
- inCache.put("deltaX",
LineageCache.get(target.getInputs()[1]).getMBValue());
+ inCache.put("deltaX",
LineageCache.getMatrix(target.getInputs()[1]));
}
}
// return true only if the last tsmm is found
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
index 9819fe0..22740f3 100644
---
a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
@@ -35,50 +35,55 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
-public class FunctionFullReuseTest extends AutomatedTestBase {
-
+public class FunctionFullReuseTest extends AutomatedTestBase
+{
protected static final String TEST_DIR = "functions/lineage/";
- protected static final String TEST_NAME1 = "FunctionFullReuse1";
- protected static final String TEST_NAME2 = "FunctionFullReuse2";
- protected static final String TEST_NAME3 = "FunctionFullReuse3";
- protected static final String TEST_NAME4 = "FunctionFullReuse4";
- protected static final String TEST_NAME5 = "FunctionFullReuse5";
+ protected static final String TEST_NAME = "FunctionFullReuse";
+ protected static final int TEST_VARIANTS = 7;
+
protected String TEST_CLASS_DIR = TEST_DIR +
FunctionFullReuseTest.class.getSimpleName() + "/";
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
- addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
- addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3));
- addTestConfiguration(TEST_NAME4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4));
- addTestConfiguration(TEST_NAME5, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5));
+ for( int i=1; i<=TEST_VARIANTS; i++ )
+ addTestConfiguration(TEST_NAME+i, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i));
}
@Test
public void testCacheHit() {
- testLineageTrace(TEST_NAME1);
+ testLineageTrace(TEST_NAME+"1");
}
@Test
public void testCacheMiss() {
- testLineageTrace(TEST_NAME2);
+ testLineageTrace(TEST_NAME+"2");
}
@Test
public void testMultipleReturns() {
- testLineageTrace(TEST_NAME3);
+ testLineageTrace(TEST_NAME+"3");
}
@Test
public void testNestedFunc() {
- testLineageTrace(TEST_NAME4);
+ testLineageTrace(TEST_NAME+"4");
}
@Test
public void testStepLM() {
- testLineageTrace(TEST_NAME5);
- }
+ testLineageTrace(TEST_NAME+"5");
+ }
+
+ @Test
+ public void testParforIssue1() {
+ testLineageTrace(TEST_NAME+"6");
+ }
+
+ @Test
+ public void testParforIssue2() {
+ testLineageTrace(TEST_NAME+"7");
+ }
public void testLineageTrace(String testname) {
boolean old_simplification =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -132,4 +137,3 @@ public class FunctionFullReuseTest extends
AutomatedTestBase {
}
}
}
-
diff --git a/src/test/scripts/functions/lineage/FunctionFullReuse6.dml
b/src/test/scripts/functions/lineage/FunctionFullReuse6.dml
new file mode 100644
index 0000000..2b025d5
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FunctionFullReuse6.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Every parfor iteration calls foo on the same input X (independent of i),
+# so all iterations produce identical intermediates that are candidates
+# for full function reuse across concurrent parfor workers.
+foo = function(Matrix[Double] X) return (Matrix[Double] R) {
+ y = X + X - 2 * sqrt(X) + X * X;
+ # presumably used to force a statement-block boundary -- confirm
+ while(FALSE){}
+ R = rowSums(y)*colSums(y);
+}
+
+X = rand(rows=100, cols=10, seed=7);
+while(FALSE){}
+X = X + 1;
+
+R = matrix(0, 1, ncol(X));
+parfor(i in 1:10) {
+ R[,i] = sum(foo(X));
+}
+
+write(R, $1, format="text");
diff --git a/src/test/scripts/functions/lineage/FunctionFullReuse7.dml
b/src/test/scripts/functions/lineage/FunctionFullReuse7.dml
new file mode 100644
index 0000000..e4b64d8
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FunctionFullReuse7.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Every parfor iteration calls foo on column X[,i], so results depend on
+# the parfor iteration variable i; cached results must not be reused
+# across iterations with different i.
+foo = function(Matrix[Double] X) return (Matrix[Double] R) {
+ y = X + X - 2 * sqrt(X) + X * X;
+ # presumably used to force a statement-block boundary -- confirm
+ while(FALSE){}
+ R = rowSums(y)*colSums(y);
+}
+
+X = rand(rows=100, cols=10, seed=7);
+while(FALSE){}
+X = X + 1;
+
+R = matrix(0, 1, ncol(X));
+parfor(i in 1:10) {
+ R[,i] = sum(foo(X[,i]));
+}
+
+write(R, $1, format="text");