This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c49e25dfe9 [SYSTEMDS-3891] Improved Stream Handling and PCA support
c49e25dfe9 is described below
commit c49e25dfe952f1440fb0de57a382ad0e1fbf12d5
Author: Jannik Lindemann <[email protected]>
AuthorDate: Mon Dec 29 10:11:49 2025 +0100
[SYSTEMDS-3891] Improved Stream Handling and PCA support
Closes #2368.
---
src/main/java/org/apache/sysds/api/DMLScript.java | 3 +
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 3 +-
.../sysds/hops/rewrite/RewriteInjectOOCTee.java | 212 +++++++++++++++------
.../controlprogram/caching/CacheableData.java | 15 +-
.../controlprogram/parfor/LocalTaskQueue.java | 2 +-
.../instructions/cp/VariableCPInstruction.java | 6 +
.../runtime/instructions/ooc/CachingStream.java | 161 +++++++++++++---
.../instructions/ooc/IndexingOOCInstruction.java | 28 +++
.../ooc/MatrixIndexingOOCInstruction.java | 83 ++++----
.../instructions/ooc/OOCEvictionManager.java | 36 +++-
.../runtime/instructions/ooc/OOCInstruction.java | 202 +++++++++++++-------
.../sysds/runtime/instructions/ooc/OOCStream.java | 26 +++
.../runtime/instructions/ooc/OOCStreamable.java | 2 -
.../runtime/instructions/ooc/OOCWatchdog.java | 96 ++++++++++
.../ooc/ParameterizedBuiltinOOCInstruction.java | 27 ++-
.../runtime/instructions/ooc/PlaybackStream.java | 66 ++++++-
.../instructions/ooc/SubscribableTaskQueue.java | 169 ++++++++++++----
.../instructions/ooc/TeeOOCInstruction.java | 42 +++-
.../apache/sysds/test/functions/ooc/PCATest.java | 123 ++++++++++++
.../sysds/test/usertest/pythonapi/StartupTest.java | 2 +
src/test/scripts/functions/ooc/PCA.dml | 28 +++
21 files changed, 1060 insertions(+), 272 deletions(-)
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java
b/src/main/java/org/apache/sysds/api/DMLScript.java
index e6becd83d1..91fb0ba3d1 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -71,6 +71,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.monitoring.FederatedMon
import
org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysds.runtime.instructions.ooc.OOCEvictionManager;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy;
@@ -497,6 +498,8 @@ public class DMLScript
ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec,
ConfigurationManager.getDMLConfig(), STATISTICS ? STATISTICS_COUNT : 0, null);
}
finally {
+ //cleanup OOC streams and cache
+ OOCEvictionManager.reset();
//cleanup scratch_space and all working dirs
cleanupHadoopExecution(ConfigurationManager.getDMLConfig());
FederatedData.clearWorkGroup();
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index c2602dba51..026c242a80 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -77,7 +77,6 @@ public class ProgramRewriter{
//add static HOP DAG rewrite rules
_dagRuleSet.add( new RewriteRemoveReadAfterWrite()
); //dependency: before blocksize
_dagRuleSet.add( new RewriteBlockSizeAndReblock()
);
- _dagRuleSet.add( new RewriteInjectOOCTee()
);
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
_dagRuleSet.add( new
RewriteRemoveUnnecessaryCasts() );
if(
OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
@@ -94,6 +93,7 @@ public class ProgramRewriter{
if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
_dagRuleSet.add( new
RewriteQuantizationFusedCompression() );
+
//add statement block rewrite rules
if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
_sbRuleSet.add( new
RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding
@@ -152,6 +152,7 @@ public class ProgramRewriter{
_dagRuleSet.add( new RewriteConstantFolding()
); //dependency: cse
_sbRuleSet.add( new RewriteRemoveEmptyBasicBlocks()
);
_sbRuleSet.add( new RewriteRemoveEmptyForLoops()
);
+ _sbRuleSet.add( new RewriteInjectOOCTee()
);
}
/**
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
index 54dffa263e..7abfb15d1d 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.parser.StatementBlock;
import java.util.ArrayList;
import java.util.HashMap;
@@ -49,73 +50,20 @@ import java.util.Set;
* 2. <b>Apply Rewrites (Modification):</b> Iterate over the collected
candidate and put
* {@code TeeOp}, and safely rewire the graph.
*/
-public class RewriteInjectOOCTee extends HopRewriteRule {
+public class RewriteInjectOOCTee extends StatementBlockRewriteRule {
public static boolean APPLY_ONLY_XtX_PATTERN = false;
+
+ private static final Map<String, Integer> _transientVars = new
HashMap<>();
+ private static final Map<String, List<Hop>> _transientHops = new
HashMap<>();
+ private static final Set<String> teeTransientVars = new HashSet<>();
private static final Set<Long> rewrittenHops = new HashSet<>();
private static final Map<Long, Hop> handledHop = new HashMap<>();
// Maintain a list of candidates to rewrite in the second pass
private final List<Hop> rewriteCandidates = new ArrayList<>();
-
- /**
- * Handle a generic (last-level) hop DAG with multiple roots.
- *
- * @param roots high-level operator roots
- * @param state program rewrite status
- * @return list of high-level operators
- */
- @Override
- public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots,
ProgramRewriteStatus state) {
- if (roots == null) {
- return null;
- }
-
- // Clear candidates for this pass
- rewriteCandidates.clear();
-
- // PASS 1: Identify candidates without modifying the graph
- for (Hop root : roots) {
- root.resetVisitStatus();
- findRewriteCandidates(root);
- }
-
- // PASS 2: Apply rewrites to identified candidates
- for (Hop candidate : rewriteCandidates) {
- applyTopDownTeeRewrite(candidate);
- }
-
- return roots;
- }
-
- /**
- * Handle a predicate hop DAG with exactly one root.
- *
- * @param root high-level operator root
- * @param state program rewrite status
- * @return high-level operator
- */
- @Override
- public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
- if (root == null) {
- return null;
- }
-
- // Clear candidates for this pass
- rewriteCandidates.clear();
-
- // PASS 1: Identify candidates without modifying the graph
- root.resetVisitStatus();
- findRewriteCandidates(root);
-
- // PASS 2: Apply rewrites to identified candidates
- for (Hop candidate : rewriteCandidates) {
- applyTopDownTeeRewrite(candidate);
- }
-
- return root;
- }
+ private boolean forceTee = false;
/**
* First pass: Find candidates for rewrite without modifying the graph.
@@ -137,6 +85,35 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
findRewriteCandidates(input);
}
+ boolean isRewriteCandidate = DMLScript.USE_OOC
+ && hop.getDataType().isMatrix()
+ && !HopRewriteUtils.isData(hop, OpOpData.TEE)
+ && hop.getParent().size() > 1
+ && (!APPLY_ONLY_XtX_PATTERN ||
isSelfTranposePattern(hop));
+
+ if (HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) &&
hop.getDataType().isMatrix()) {
+ _transientVars.compute(hop.getName(), (key, ctr) -> {
+ int incr = (isRewriteCandidate || forceTee) ? 2
: 1;
+
+ int ret = ctr == null ? 0 : ctr;
+ ret += incr;
+
+ if (ret > 1)
+ teeTransientVars.add(hop.getName());
+
+ return ret;
+ });
+
+ _transientHops.compute(hop.getName(), (key, hops) -> {
+ if (hops == null)
+ return new ArrayList<>(List.of(hop));
+ hops.add(hop);
+ return hops;
+ });
+
+ return; // We do not tee transient reads but rather
inject before TWrite or PRead as caching stream
+ }
+
// Check if this hop is a candidate for OOC Tee injection
if (DMLScript.USE_OOC
&& hop.getDataType().isMatrix()
@@ -160,11 +137,17 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
return;
}
+ int consumerCount = sharedInput.getParent().size();
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Inject tee for hop " +
sharedInput.getHopID() + " ("
+ + sharedInput.getName() + "), consumers=" +
consumerCount);
+ }
+
// Take a defensive copy of consumers before modifying the graph
ArrayList<Hop> consumers = new
ArrayList<>(sharedInput.getParent());
// Create the new TeeOp with the original hop as input
- DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
+ DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
sharedInput.getDataType(), sharedInput.getValueType(),
Types.OpOpData.TEE, null,
sharedInput.getDim1(), sharedInput.getDim2(),
sharedInput.getNnz(), sharedInput.getBlocksize());
HopRewriteUtils.addChildReference(teeOp, sharedInput);
@@ -177,6 +160,11 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
// Record that we've handled this hop
handledHop.put(sharedInput.getHopID(), teeOp);
rewrittenHops.add(sharedInput.getHopID());
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Created tee hop " + teeOp.getHopID() + " -> "
+ + teeOp.getName());
+ }
}
@SuppressWarnings("unused")
@@ -196,4 +184,108 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
}
return hasTransposeConsumer && hasMatrixMultiplyConsumer;
}
+
+	/**
+	 * This rule only re-wires hops inside the given statement blocks and
+	 * never splits a HOP DAG.
+	 */
+	@Override
+	public boolean createsSplitDag() {
+		return false;
+	}
+
+	/**
+	 * Injects OOC tee operators into a single statement block.
+	 * <p>
+	 * Regular candidates are rewritten per block. Afterwards, every collected
+	 * transient read of a variable recorded in {@code teeTransientVars} is tee'd.
+	 * Finally, directly nested tee chains are collapsed.
+	 *
+	 * @param sb    statement block to rewrite (modified in place)
+	 * @param state program rewrite status (not used by this rule)
+	 * @return a singleton list containing {@code sb}
+	 */
+	@Override
+	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+		if (!DMLScript.USE_OOC)
+			return List.of(sb);
+
+		rewriteSB(sb, state);
+		applyTransientTeeRewrites();
+		removeRedundantTeeChains(sb);
+
+		return List.of(sb);
+	}
+
+	/**
+	 * Injects OOC tee operators into a sequence of statement blocks.
+	 * <p>
+	 * All blocks are scanned before transient reads are tee'd, so reads of the
+	 * same variable across the whole sequence contribute to the decision.
+	 *
+	 * @param sbs   statement blocks to rewrite (modified in place)
+	 * @param state program rewrite status (not used by this rule)
+	 * @return the same list {@code sbs}
+	 */
+	@Override
+	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
+		if (!DMLScript.USE_OOC)
+			return sbs;
+
+		for (StatementBlock sb : sbs)
+			rewriteSB(sb, state);
+
+		// Shared with the single-block path, so that both clear processed hops.
+		applyTransientTeeRewrites();
+
+		for (StatementBlock sb : sbs)
+			removeRedundantTeeChains(sb);
+
+		return sbs;
+	}
+
+	/**
+	 * Tees every collected transient read of the variables in
+	 * {@code teeTransientVars}, then clears the processed hop lists. This keeps
+	 * hops of already rewritten blocks from being handed to the rewrite again
+	 * and from staying reachable through the static {@code _transientHops} map.
+	 */
+	private void applyTransientTeeRewrites() {
+		for (String tVar : teeTransientVars) {
+			List<Hop> tHops = _transientHops.get(tVar);
+			if (tHops == null)
+				continue;
+
+			for (Hop tHop : tHops)
+				applyTopDownTeeRewrite(tHop);
+
+			tHops.clear();
+		}
+	}
+
+	/**
+	 * Runs the two-pass candidate search and tee injection on the root hops of
+	 * one statement block. Transient reads are only recorded during the search
+	 * and are tee'd later by {@link #applyTransientTeeRewrites()}.
+	 */
+	private void rewriteSB(StatementBlock sb, ProgramRewriteStatus state) {
+		rewriteCandidates.clear();
+
+		if (sb.getHops() != null) {
+			for (Hop hop : sb.getHops()) {
+				hop.resetVisitStatus();
+				findRewriteCandidates(hop);
+			}
+		}
+
+		for (Hop candidate : rewriteCandidates)
+			applyTopDownTeeRewrite(candidate);
+	}
+
+	/**
+	 * Collapses redundant tee chains in all root hops of a statement block.
+	 */
+	private void removeRedundantTeeChains(StatementBlock sb) {
+		if (sb == null || sb.getHops() == null)
+			return;
+
+		Hop.resetVisitStatus(sb.getHops());
+		for (Hop hop : sb.getHops())
+			removeRedundantTeeChains(hop);
+		Hop.resetVisitStatus(sb.getHops());
+	}
+
+	/**
+	 * Post-order walk that collapses tee-of-tee chains. A tee whose single input
+	 * is another tee is bypassed: all of its parents are re-wired to the inner
+	 * tee, and the outer tee is detached from its input.
+	 */
+	private void removeRedundantTeeChains(Hop hop) {
+		if (hop.isVisited())
+			return;
+
+		// Iterate over a copy, since collapsing a child tee re-wires this hop's inputs.
+		ArrayList<Hop> inputs = new ArrayList<>(hop.getInput());
+		for (Hop in : inputs)
+			removeRedundantTeeChains(in);
+
+		if (HopRewriteUtils.isData(hop, OpOpData.TEE) && hop.getInput().size() == 1) {
+			Hop teeInput = hop.getInput().get(0);
+			if (HopRewriteUtils.isData(teeInput, OpOpData.TEE)) {
+				if (LOG.isDebugEnabled()) {
+					LOG.debug("Remove redundant tee hop " + hop.getHopID()
+						+ " (" + hop.getName() + ") -> " + teeInput.getHopID()
+						+ " (" + teeInput.getName() + ")");
+				}
+				HopRewriteUtils.rewireAllParentChildReferences(hop, teeInput);
+				HopRewriteUtils.removeAllChildReferences(hop);
+			}
+		}
+
+		hop.setVisited();
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 34a8aa1863..d826af89c0 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -471,12 +471,12 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
return _bcHandle != null && _bcHandle.hasBackReference();
}
- public OOCStream<IndexedMatrixValue> getStreamHandle() {
+ public synchronized OOCStream<IndexedMatrixValue> getStreamHandle() {
if( !hasStreamHandle() ) {
final SubscribableTaskQueue<IndexedMatrixValue>
_mStream = new SubscribableTaskQueue<>();
- _streamHandle = _mStream;
DataCharacteristics dc = getDataCharacteristics();
MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
+ _streamHandle = _mStream;
LongStream.range(0, dc.getNumBlocks())
.mapToObj(i ->
UtilFunctions.createIndexedMatrixBlock(src, dc, i))
.forEach( blk -> {
@@ -489,7 +489,14 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
_mStream.closeInput();
}
- return _streamHandle.getReadStream();
+ OOCStream<IndexedMatrixValue> stream =
_streamHandle.getReadStream();
+ if (!stream.hasStreamCache())
+ _streamHandle = null; // To ensure read once
+ return stream;
+ }
+
+ /**
+ * Returns the current stream handle as-is. Unlike getStreamHandle(), it
+ * neither creates a handle nor resets it after a read.
+ *
+ * @return the current streamable, or null if no stream handle is set
+ */
+ public OOCStreamable<IndexedMatrixValue> getStreamable() {
+ return _streamHandle;
}
/**
@@ -499,7 +506,7 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
* @return true if existing, false otherwise
*/
public boolean hasStreamHandle() {
- return _streamHandle != null && !_streamHandle.isProcessed();
+ return _streamHandle != null;
}
@SuppressWarnings({ "rawtypes", "unchecked" })
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
index 783981e0f1..50143cd0ad 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
@@ -45,7 +45,7 @@ public class LocalTaskQueue<T>
protected LinkedList<T> _data = null;
protected boolean _closedInput = false;
- private DMLRuntimeException _failure = null;
+ protected DMLRuntimeException _failure = null;
private static final Log LOG =
LogFactory.getLog(LocalTaskQueue.class.getName());
public LocalTaskQueue()
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 5dd8e55e82..afc446f747 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -46,6 +46,7 @@ import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
import org.apache.sysds.runtime.io.FileFormatProperties;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.FileFormatPropertiesHDF5;
@@ -1026,6 +1027,9 @@ public class VariableCPInstruction extends CPInstruction
implements LineageTrace
if ( dd == null )
throw new DMLRuntimeException("Unexpected error: could
not find a data object for variable name:" + getInput1().getName() + ", while
processing instruction " +this.toString());
+ if (DMLScript.USE_OOC && dd instanceof MatrixObject)
+
TeeOOCInstruction.incrRef(((MatrixObject)dd).getStreamable(), 1);
+
// remove existing variable bound to target name
Data input2_data = ec.removeVariable(getInput2().getName());
@@ -1117,6 +1121,8 @@ public class VariableCPInstruction extends CPInstruction
implements LineageTrace
public static void processRmvarInstruction( ExecutionContext ec, String
varname ) {
// remove variable from symbol table
Data dat = ec.removeVariable(varname);
+ if (DMLScript.USE_OOC && dat instanceof MatrixObject)
+ TeeOOCInstruction.incrRef(((MatrixObject)
dat).getStreamable(), -1);
//cleanup matrix data on fs/hdfs (if necessary)
if( dat != null )
ec.cleanupDataObject(dat);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
index d7c80e4de3..cdc2391151 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
@@ -24,6 +24,7 @@ import
org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList;
import java.util.HashMap;
import java.util.Map;
@@ -39,6 +40,7 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
// original live stream
private final OOCStream<IndexedMatrixValue> _source;
+ private final IntArrayList _consumptionCounts = new IntArrayList();
// stream identifier
private final long _streamId;
@@ -54,6 +56,10 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
private DMLRuntimeException _failure;
+ private boolean deletable = false;
+ private int maxConsumptionCount = 0;
+ private int cachePins = 0;
+
public CachingStream(OOCStream<IndexedMatrixValue> source) {
this(source, _streamSeq.getNextID());
}
@@ -61,23 +67,43 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
public CachingStream(OOCStream<IndexedMatrixValue> source, long
streamId) {
_source = source;
_streamId = streamId;
- source.setSubscriber(() -> {
+ source.setSubscriber(tmp -> {
try {
- boolean closed = fetchFromStream();
- Runnable[] mSubscribers = _subscribers;
+ final IndexedMatrixValue task = tmp.get();
+ int blk;
+ Runnable[] mSubscribers;
+
+ synchronized (this) {
+ if(task !=
LocalTaskQueue.NO_MORE_TASKS) {
+ if (!_cacheInProgress)
+ throw new
DMLRuntimeException("Stream is closed");
+
OOCEvictionManager.put(_streamId, _numBlocks, task);
+ if (_index != null)
+
_index.put(task.getIndexes(), _numBlocks);
+ blk = _numBlocks;
+ _numBlocks++;
+ _consumptionCounts.add(0);
+ notifyAll();
+ }
+ else {
+ _cacheInProgress = false; //
caching is complete
+ notifyAll();
+ blk = -1;
+ }
+
+ mSubscribers = _subscribers;
+ }
if(mSubscribers != null) {
for(Runnable mSubscriber : mSubscribers)
mSubscriber.run();
- if (closed) {
+ if (blk == -1) {
synchronized (this) {
_subscribers = null;
}
}
}
- } catch (InterruptedException e) {
- throw new DMLRuntimeException(e);
} catch (DMLRuntimeException e) {
// Propagate failure to subscribers
_failure = e;
@@ -98,25 +124,28 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
});
}
- private synchronized boolean fetchFromStream() throws
InterruptedException {
- if(!_cacheInProgress)
- throw new DMLRuntimeException("Stream is closed");
+	/**
+	 * Flags this stream for deletion. From then on, no new subscribers are
+	 * accepted. A block is dropped from the cache once its consumption count
+	 * reaches the expected number of consumptions, unless the stream is pinned.
+	 * Blocks that already reached that count are dropped immediately.
+	 *
+	 * @throws DMLRuntimeException if caching is still in progress but no
+	 *         consumption has been registered
+	 */
+	public synchronized void scheduleDeletion() {
+		// Validate before flagging, so a rejected call leaves the stream usable.
+		if (_cacheInProgress && maxConsumptionCount == 0)
+			throw new DMLRuntimeException("Cannot have a caching stream with no listeners");
+		deletable = true;
+		for (int i = 0; i < _consumptionCounts.size(); i++)
+			tryDeleteBlock(i);
+	}
- IndexedMatrixValue task = _source.dequeue();
+	/**
+	 * Identifies this stream by its id, e.g. in error messages.
+	 */
+	@Override
+	public String toString() {
+		return "CachingStream@" + _streamId;
+	}
- if(task != LocalTaskQueue.NO_MORE_TASKS) {
- OOCEvictionManager.put(_streamId, _numBlocks, task);
- if (_index != null)
- _index.put(task.getIndexes(), _numBlocks);
- _numBlocks++;
- notifyAll();
- return false;
- }
- else {
- _cacheInProgress = false; // caching is complete
- notifyAll();
- return true;
- }
+ private synchronized void tryDeleteBlock(int i) {
+ if (cachePins > 0)
+ return; // Block deletion is prevented
+
+ int count = _consumptionCounts.getInt(i);
+ if (count > maxConsumptionCount)
+ throw new DMLRuntimeException("Cannot have more than "
+ maxConsumptionCount + " consumptions.");
+ if (count == maxConsumptionCount)
+ OOCEvictionManager.forget(_streamId, i);
}
public synchronized IndexedMatrixValue get(int idx) throws
InterruptedException {
@@ -129,6 +158,16 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
if (_index != null) // Ensure index is up to
date
_index.putIfAbsent(out.getIndexes(),
idx);
+ int newCount = _consumptionCounts.getInt(idx)+1;
+
+ if (newCount > maxConsumptionCount)
+ throw new DMLRuntimeException("Consumer
overflow! Expected: " + maxConsumptionCount);
+
+ _consumptionCounts.set(idx, newCount);
+
+ if (deletable)
+ tryDeleteBlock(idx);
+
return out;
} else if (!_cacheInProgress)
return
(IndexedMatrixValue)LocalTaskQueue.NO_MORE_TASKS;
@@ -137,8 +176,31 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
}
}
+	/**
+	 * Resolves the position of a cached block without counting a consumption.
+	 *
+	 * @param idx matrix indexes of the block
+	 * @return position of the block within this stream
+	 * @throws DMLRuntimeException if indexing is not active or the block is unknown
+	 */
+	public synchronized int findCachedIndex(MatrixIndexes idx) {
+		return resolveCachedIndex(idx);
+	}
+
+	/**
+	 * Looks up a block position in the index map. Callers must hold this
+	 * stream's lock. This replaces a bare unboxing of {@code _index.get(idx)},
+	 * which fails with an uninformative NullPointerException.
+	 */
+	private int resolveCachedIndex(MatrixIndexes idx) {
+		if (_index == null)
+			throw new DMLRuntimeException("Indexing has not been activated on " + this);
+		Integer pos = _index.get(idx);
+		if (pos == null)
+			throw new DMLRuntimeException("Block " + idx + " is not cached in " + this);
+		return pos;
+	}
+
+	/**
+	 * Returns a cached block and counts it as one consumption. Once the stream
+	 * is scheduled for deletion, this may drop the block from the cache.
+	 *
+	 * @throws DMLRuntimeException on an unknown block or a consumption overflow
+	 */
 	public synchronized IndexedMatrixValue findCached(MatrixIndexes idx) {
-		return OOCEvictionManager.get(_streamId, _index.get(idx));
+		int mIdx = resolveCachedIndex(idx);
+		int newCount = _consumptionCounts.getInt(mIdx) + 1;
+		if (newCount > maxConsumptionCount)
+			throw new DMLRuntimeException("Consumer overflow in " + _streamId + "_" + mIdx + ". Expected: " + maxConsumptionCount);
+		_consumptionCounts.set(mIdx, newCount);
+
+		IndexedMatrixValue imv = OOCEvictionManager.get(_streamId, mIdx);
+
+		if (deletable)
+			tryDeleteBlock(mIdx);
+
+		return imv;
+	}
+
+	/**
+	 * Finds a cached item without counting it as a consumption.
+	 *
+	 * @throws DMLRuntimeException if indexing is not active or the block is unknown
+	 */
+	public synchronized IndexedMatrixValue peekCached(MatrixIndexes idx) {
+		return OOCEvictionManager.get(_streamId, resolveCachedIndex(idx));
 	}
public synchronized void activateIndexing() {
@@ -161,12 +223,18 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
return false;
}
- @Override
- public void setSubscriber(Runnable subscriber) {
+ public void setSubscriber(Runnable subscriber, boolean incrConsumers) {
+ if (deletable)
+ throw new DMLRuntimeException("Cannot register a new
subscriber on " + this + " because has been flagged for deletion");
+
int mNumBlocks;
+ boolean cacheInProgress;
synchronized (this) {
mNumBlocks = _numBlocks;
- if (_cacheInProgress) {
+ cacheInProgress = _cacheInProgress;
+ if (incrConsumers)
+ maxConsumptionCount++;
+ if (cacheInProgress) {
int newLen = _subscribers == null ? 1 :
_subscribers.length + 1;
Runnable[] newSubscribers = new
Runnable[newLen];
@@ -181,7 +249,44 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
for (int i = 0; i < mNumBlocks; i++)
subscriber.run();
- if (!_cacheInProgress)
+ if (!cacheInProgress)
subscriber.run(); // To fetch the NO_MORE_TASK element
}
+
+	/**
+	 * Artificially increase subscriber count.
+	 * Only use if certain blocks are accessed more than once.
+	 *
+	 * @param count number of additional expected consumptions per block
+	 */
+	public synchronized void incrSubscriberCount(int count) {
+		maxConsumptionCount += count;
+	}
+
+	/**
+	 * Artificially increase the processing count of a block. If the stream is
+	 * scheduled for deletion, the block may be dropped afterwards.
+	 *
+	 * @param i     position of the block within this stream
+	 * @param count number of consumptions to add
+	 */
+	public synchronized void incrProcessingCount(int i, int count) {
+		_consumptionCounts.set(i, _consumptionCounts.getInt(i) + count);
+
+		if (deletable)
+			tryDeleteBlock(i);
+	}
+
+	/**
+	 * Force pins blocks in the cache to not be subject to block deletion.
+	 * Every call must be balanced by a call to {@link #unpinStream()}.
+	 */
+	public synchronized void pinStream() {
+		cachePins++;
+	}
+
+	/**
+	 * Releases one pin. When the last pin is released and the stream has been
+	 * scheduled for deletion, fully consumed blocks are dropped from the cache.
+	 *
+	 * @throws DMLRuntimeException if the stream is not pinned
+	 */
+	public synchronized void unpinStream() {
+		if (cachePins <= 0)
+			throw new DMLRuntimeException("Cannot unpin " + this + " because it is not pinned");
+		cachePins--;
+
+		// Only streams flagged via scheduleDeletion() may drop blocks. Otherwise
+		// new subscribers can still register and replay the cached blocks.
+		if (cachePins == 0 && deletable) {
+			for (int i = 0; i < _consumptionCounts.size(); i++)
+				tryDeleteBlock(i);
+		}
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java
index 1d555da8d6..175d81d6e0 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java
@@ -115,6 +115,34 @@ public abstract class IndexingOOCInstruction extends
UnaryOOCInstruction {
return (_indexRange.rowStart % _blocksize) == 0 &&
(_indexRange.colStart % _blocksize) == 0;
}
+ /**
+ * Counts how many output blocks of the indexing result take data from the
+ * given input block, i.e. how often that block is consumed while the output
+ * is assembled.
+ *
+ * @param index 1-based block indexes of an input block
+ * @return 0 if the block is outside _blockRange or does not overlap the
+ * index range; otherwise the number of overlapped output blocks
+ */
+ public int getNumConsumptions(MatrixIndexes index) {
+ long blockRow = index.getRowIndex() - 1;
+ long blockCol = index.getColumnIndex() - 1;
+
+ if(!_blockRange.isWithin(blockRow, blockCol))
+ return 0;
+
+ // 0-based, inclusive cell range covered by this input block
+ long blockRowStart = blockRow * _blocksize;
+ long blockRowEnd = blockRowStart + _blocksize - 1;
+ long blockColStart = blockCol * _blocksize;
+ long blockColEnd = blockColStart + _blocksize - 1;
+
+ // intersect with _indexRange, treated here as 0-based with inclusive ends
+ long overlapRowStart = Math.max(_indexRange.rowStart,
blockRowStart);
+ long overlapRowEnd = Math.min(_indexRange.rowEnd,
blockRowEnd);
+ long overlapColStart = Math.max(_indexRange.colStart,
blockColStart);
+ long overlapColEnd = Math.min(_indexRange.colEnd,
blockColEnd);
+
+ if(overlapRowStart > overlapRowEnd || overlapColStart >
overlapColEnd)
+ return 0;
+
+ // map the overlap to 0-based output block coordinates (relative to the range start)
+ int outRowStart = (int) ((overlapRowStart -
_indexRange.rowStart) / _blocksize);
+ int outRowEnd = (int) ((overlapRowEnd -
_indexRange.rowStart) / _blocksize);
+ int outColStart = (int) ((overlapColStart -
_indexRange.colStart) / _blocksize);
+ int outColEnd = (int) ((overlapColEnd -
_indexRange.colStart) / _blocksize);
+
+ return (outRowEnd - outRowStart + 1) * (outColEnd -
outColStart + 1);
+ }
+
public boolean putNext(MatrixIndexes index, T data,
BiConsumer<MatrixIndexes, Sector<T>> emitter) {
long blockRow = index.getRowIndex() - 1;
long blockCol = index.getColumnIndex() - 1;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java
index a04a77677c..c7cd8c9d3f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java
@@ -33,6 +33,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.util.IndexRange;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@@ -43,11 +44,6 @@ public class MatrixIndexingOOCInstruction extends
IndexingOOCInstruction {
super(in, rl, ru, cl, cu, out, opcode, istr);
}
- protected MatrixIndexingOOCInstruction(CPOperand lhsInput, CPOperand
rhsInput, CPOperand rl, CPOperand ru,
- CPOperand cl, CPOperand cu, CPOperand out, String opcode,
String istr) {
- super(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr);
- }
-
@Override
public void processInstruction(ExecutionContext ec) {
String opcode = getOpcode();
@@ -96,8 +92,9 @@ public class MatrixIndexingOOCInstruction extends
IndexingOOCInstruction {
final int outBlockCols = (int)
Math.ceil((double) (ix.colSpan() + 1) / blocksize);
final int totalBlocks = outBlockRows *
outBlockCols;
final AtomicInteger producedBlocks = new
AtomicInteger(0);
+ CompletableFuture<Void> future = new
CompletableFuture<>();
- CompletableFuture<Void> future = filterOOC(qIn,
tmp -> {
+ filterOOC(qIn, tmp -> {
MatrixIndexes inIdx = tmp.getIndexes();
long blockRow = inIdx.getRowIndex() - 1;
long blockCol = inIdx.getColumnIndex()
- 1;
@@ -124,12 +121,12 @@ public class MatrixIndexingOOCInstruction extends
IndexingOOCInstruction {
long outBlockCol = blockCol -
firstBlockCol + 1;
qOut.enqueue(new IndexedMatrixValue(new
MatrixIndexes(outBlockRow, outBlockCol), outBlock));
- if(producedBlocks.incrementAndGet() >=
totalBlocks) {
- CompletableFuture<Void> f =
futureRef.get();
- if(f != null)
- f.cancel(true);
- }
+ if(producedBlocks.incrementAndGet() >=
totalBlocks)
+ future.complete(null);
}, tmp -> {
+ if (future.isDone()) // Then we may
skip blocks and avoid submitting tasks
+ return false;
+
long blockRow =
tmp.getIndexes().getRowIndex() - 1;
long blockCol =
tmp.getIndexes().getColumnIndex() - 1;
return blockRow >= firstBlockRow &&
blockRow <= lastBlockRow && blockCol >= firstBlockCol &&
@@ -139,20 +136,23 @@ public class MatrixIndexingOOCInstruction extends
IndexingOOCInstruction {
return;
}
- final BlockAligner<IndexedBlockMeta> aligner = new
BlockAligner<>(ix, blocksize);
+ final BlockAligner<MatrixIndexes> aligner = new
BlockAligner<>(ix, blocksize);
+ final ConcurrentHashMap<MatrixIndexes, Integer>
consumptionCounts = new ConcurrentHashMap<>();
// We may need to construct our own intermediate stream
to properly manage the cached items
boolean hasIntermediateStream = !qIn.hasStreamCache();
final CachingStream cachedStream =
hasIntermediateStream ? new CachingStream(new SubscribableTaskQueue<>()) :
qOut.getStreamCache();
cachedStream.activateIndexing();
+ cachedStream.incrSubscriberCount(1); // We may require
re-consumption of blocks (up to 4 times)
+ final CompletableFuture<Void> future = new
CompletableFuture<>();
- CompletableFuture<Void> future =
filterOOC(qIn.getReadStream(), tmp -> {
+ filterOOC(qIn.getReadStream(), tmp -> {
if (hasIntermediateStream) {
// We write to an intermediate stream
to ensure that these matrix blocks are properly cached
cachedStream.getWriteStream().enqueue(tmp);
}
- boolean completed =
aligner.putNext(tmp.getIndexes(), new IndexedBlockMeta(tmp), (idx, sector) -> {
+ boolean completed =
aligner.putNext(tmp.getIndexes(), tmp.getIndexes(), (idx, sector) -> {
int targetBlockRow = (int)
(idx.getRowIndex() - 1);
int targetBlockCol = (int)
(idx.getColumnIndex() - 1);
@@ -176,18 +176,18 @@ public class MatrixIndexingOOCInstruction extends
IndexingOOCInstruction {
for(int r = 0; r < rowSegments; r++) {
for(int c = 0; c < colSegments;
c++) {
- IndexedBlockMeta ibm =
sector.get(r, c);
- if(ibm == null)
+ MatrixIndexes mIdx =
sector.get(r, c);
+ if(mIdx == null)
continue;
- IndexedMatrixValue mv =
cachedStream.findCached(ibm.idx);
+ IndexedMatrixValue mv =
cachedStream.peekCached(mIdx);
MatrixBlock srcBlock =
(MatrixBlock) mv.getValue();
if(target == null)
target = new
MatrixBlock(nRows, nCols, srcBlock.isInSparseFormat());
- long srcBlockRowStart =
(ibm.idx.getRowIndex() - 1) * blocksize;
- long srcBlockColStart =
(ibm.idx.getColumnIndex() - 1) * blocksize;
+ long srcBlockRowStart =
(mIdx.getRowIndex() - 1) * blocksize;
+ long srcBlockColStart =
(mIdx.getColumnIndex() - 1) * blocksize;
long
sliceRowStartGlobal = Math.max(targetRowStartGlobal, srcBlockRowStart);
long sliceRowEndGlobal
= Math.min(targetRowEndGlobal,
srcBlockRowStart + srcBlock.getNumRows() - 1);
@@ -205,21 +205,31 @@ public class MatrixIndexingOOCInstruction extends
IndexingOOCInstruction {
MatrixBlock sliced =
srcBlock.slice(sliceRowStart, sliceRowEnd, sliceColStart, sliceColEnd);
sliced.putInto(target,
targetRowOffset, targetColOffset, true);
+ final int
maxConsumptions = aligner.getNumConsumptions(mIdx);
+
+ Integer con =
consumptionCounts.compute(mIdx, (k, v) -> {
+ if (v == null)
+ v = 0;
+ v = v+1;
+ if (v ==
maxConsumptions)
+ return
null;
+ return v;
+ });
+
+ if (con == null)
+
cachedStream.incrProcessingCount(cachedStream.findCachedIndex(mIdx), 1);
}
}
qOut.enqueue(new
IndexedMatrixValue(idx, target));
});
- if(completed) {
- // All blocks have been processed; we
can cancel the future
- // Currently, this does not affect
processing (predicates prevent task submission anyway).
- // However, a cancelled future may
allow early file read aborts once implemented.
- CompletableFuture<Void> f =
futureRef.get();
- if(f != null)
- f.cancel(true);
- }
+ if(completed)
+ future.complete(null);
}, tmp -> {
+ if (future.isDone()) // Then we may skip blocks
and avoid submitting tasks
+ return false;
+
// Pre-filter incoming blocks to avoid
unnecessary task submission
long blockRow = tmp.getIndexes().getRowIndex()
- 1;
long blockCol =
tmp.getIndexes().getColumnIndex() - 1;
@@ -228,8 +238,15 @@ public class MatrixIndexingOOCInstruction extends
IndexingOOCInstruction {
}, () -> {
aligner.close();
qOut.closeInput();
+ }, tmp -> {
+ // If elements are not processed in an existing
caching stream, we increment the process counter to allow block deletion
+ if (!hasIntermediateStream)
+
cachedStream.incrProcessingCount(cachedStream.findCachedIndex(tmp.getIndexes()),
1);
});
futureRef.set(future);
+
+ if (hasIntermediateStream)
+ cachedStream.scheduleDeletion(); // We can
immediately delete blocks after consumption
}
//left indexing
else if(opcode.equalsIgnoreCase(Opcodes.LEFT_INDEX.toString()))
{
@@ -239,16 +256,4 @@ public class MatrixIndexingOOCInstruction extends
IndexingOOCInstruction {
throw new DMLRuntimeException(
"Invalid opcode (" + opcode + ") encountered in
MatrixIndexingOOCInstruction.");
}
-
- private static class IndexedBlockMeta {
- public final MatrixIndexes idx;
- ////public final long nrows;
- //public final long ncols;
-
- public IndexedBlockMeta(IndexedMatrixValue mv) {
- this.idx = mv.getIndexes();
- //this.nrows = mv.getValue().getNumRows();
- //this.ncols = mv.getValue().getNumColumns();
- }
- }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
index 099a26ebd9..235afce833 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
@@ -106,7 +106,7 @@ import java.util.concurrent.locks.ReentrantLock;
public class OOCEvictionManager {
// Configuration: OOC buffer limit as percentage of heap
- private static final double OOC_BUFFER_PERCENTAGE = 0.15 * 0.01 * 2; //
15% of heap
+ private static final double OOC_BUFFER_PERCENTAGE = 0.15; // 15% of heap
private static final double PARTITION_EVICTION_SIZE = 64 * 1024 * 1024;
// 64 MB
@@ -192,6 +192,40 @@ public class OOCEvictionManager {
LocalFileUtils.createLocalFileIfNotExist(_spillDir);
}
+ public static void reset() {
+ TeeOOCInstruction.reset();
+ if (!_cache.isEmpty()) {
+ System.err.println("There are dangling elements in the
OOC Eviction cache: " + _cache.size());
+ }
+ _size.set(0);
+ _cache.clear();
+ _spillLocations.clear();
+ _partitions.clear();
+ _partitionCounter.set(0);
+ _streamPartitions.clear();
+ }
+
+ /**
+ * Removes a block from the cache without setting its data to null.
+ */
+ public static void forget(long streamId, int blockId) {
+ BlockEntry e;
+ synchronized (_cacheLock) {
+ e = _cache.remove(streamId + "_" + blockId);
+ }
+
+ if (e != null) {
+ e.lock.lock();
+ try {
+ if (e.state == BlockState.HOT)
+ _size.addAndGet(-e.size);
+ } finally {
+ e.lock.unlock();
+ }
+ System.out.println("Removed block " + streamId + "_" +
blockId + " from cache (idx: " + (e.value != null ? e.value.getIndexes() : "?")
+ ")");
+ }
+ }
+
/**
* Store a block in the OOC cache (serialize once)
*/
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index ca13cfdb2c..80cd6b6a87 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -131,11 +131,15 @@ public abstract class OOCInstruction extends Instruction {
return new SubscribableTaskQueue<>();
}
- protected <T, R> CompletableFuture<Void> filterOOC(OOCStream<T> qIn,
Consumer<T> processor, Function<T, Boolean> predicate, Runnable finalizer) {
+ protected <T> CompletableFuture<Void> filterOOC(OOCStream<T> qIn,
Consumer<T> processor, Function<T, Boolean> predicate, Runnable finalizer) {
+ return filterOOC(qIn, processor, predicate, finalizer, null);
+ }
+
+ protected <T> CompletableFuture<Void> filterOOC(OOCStream<T> qIn,
Consumer<T> processor, Function<T, Boolean> predicate, Runnable finalizer,
Consumer<T> onNotProcessed) {
if (_inQueues == null || _outQueues == null)
throw new NotImplementedException("filterOOC requires
manual specification of all input and output streams for error propagation");
- return submitOOCTasks(qIn, processor, finalizer, predicate);
+ return submitOOCTasks(qIn, processor, finalizer, predicate,
onNotProcessed != null ? (i, tmp) -> onNotProcessed.accept(tmp) : null);
}
protected <T, R> CompletableFuture<Void> mapOOC(OOCStream<T> qIn,
OOCStream<R> qOut, Function<T, R> mapper) {
@@ -163,10 +167,16 @@ public abstract class OOCInstruction extends Instruction {
leftCache.activateIndexing();
rightCache.activateIndexing();
+ if (!explicitLeftCaching)
+ leftCache.incrSubscriberCount(1); // Prevent early
block deletion as we may read elements twice
+
+ if (!explicitRightCaching)
+ rightCache.incrSubscriberCount(1);
+
Map<P, List<MatrixIndexes>> availableLeftInput = new
ConcurrentHashMap<>();
Map<P, BroadcastedElement> availableBroadcastInput = new
ConcurrentHashMap<>();
- return submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> {
+ CompletableFuture<Void> future = submitOOCTasks(List.of(qIn,
broadcast), (i, tmp) -> {
P key = on.apply(tmp);
if (i == 0) { // qIn stream
@@ -184,11 +194,22 @@ public abstract class OOCInstruction extends Instruction {
return v;
});
} else {
+ if (!explicitLeftCaching)
+
leftCache.incrProcessingCount(leftCache.findCachedIndex(tmp.getIndexes()), 1);
// Correct for incremented subscriber count to allow block deletion
+
+ b.value = rightCache.peekCached(b.idx);
+
// Directly emit
qOut.enqueue(mapper.apply(tmp, b));
- if (b.canRelease())
+ b.value = null;
+
+ if (b.canRelease()) {
availableBroadcastInput.remove(key);
+
+ if (!explicitRightCaching)
+
rightCache.incrProcessingCount(rightCache.findCachedIndex(b.idx), 1); //
Correct for incremented subscriber count to allow block deletion
+ }
}
} else { // broadcast stream
if (explicitRightCaching)
@@ -201,16 +222,33 @@ public abstract class OOCInstruction extends Instruction {
if (queued != null) {
for(MatrixIndexes idx : queued) {
- b.value =
rightCache.findCached(b.idx);
+ b.value =
rightCache.peekCached(b.idx); // Only peek to prevent block deletion
qOut.enqueue(mapper.apply(leftCache.findCached(idx), b));
b.value = null;
}
}
- if (b.canRelease())
+ if (b.canRelease()) {
availableBroadcastInput.remove(key);
+
+ if (!explicitRightCaching)
+
rightCache.incrProcessingCount(rightCache.findCachedIndex(tmp.getIndexes()),
1); // Correct for incremented subscriber count to allow block deletion
+ }
}
- }, qOut::closeInput);
+ }, () -> {
+ availableBroadcastInput.forEach((k, v) -> {
+
rightCache.incrProcessingCount(rightCache.findCachedIndex(v.idx), 1);
+ });
+ availableBroadcastInput.clear();
+ qOut.closeInput();
+ });
+
+ if (explicitLeftCaching)
+ leftCache.scheduleDeletion();
+ if (explicitRightCaching)
+ rightCache.scheduleDeletion();
+
+ return future;
}
protected static class BroadcastedElement {
@@ -244,7 +282,7 @@ public abstract class OOCInstruction extends Instruction {
public IndexedMatrixValue getValue() {
return value;
}
- };
+ }
protected <T, R, P> CompletableFuture<Void> joinOOC(OOCStream<T> qIn1,
OOCStream<T> qIn2, OOCStream<R> qOut, BiFunction<T, T, R> mapper, Function<T,
P> on) {
return joinOOC(qIn1, qIn2, qOut, mapper, on, on);
@@ -257,12 +295,18 @@ public abstract class OOCInstruction extends Instruction {
final CompletableFuture<Void> future = new
CompletableFuture<>();
+ boolean explicitLeftCaching = !qIn1.hasStreamCache();
+ boolean explicitRightCaching = !qIn2.hasStreamCache();
+
// We need to construct our own stream to properly manage the
cached items in the hash join
- CachingStream leftCache = qIn1.hasStreamCache() ?
qIn1.getStreamCache() : new
CachingStream((SubscribableTaskQueue<IndexedMatrixValue>)qIn1); // We have to
assume this generic type for now
- CachingStream rightCache = qIn2.hasStreamCache() ?
qIn2.getStreamCache() : new
CachingStream((SubscribableTaskQueue<IndexedMatrixValue>)qIn2); // We have to
assume this generic type for now
+ CachingStream leftCache = explicitLeftCaching ? new
CachingStream((OOCStream<IndexedMatrixValue>) qIn1) : qIn1.getStreamCache();
+ CachingStream rightCache = explicitRightCaching ? new
CachingStream((OOCStream<IndexedMatrixValue>) qIn2) : qIn2.getStreamCache();
leftCache.activateIndexing();
rightCache.activateIndexing();
+ leftCache.incrSubscriberCount(1);
+ rightCache.incrSubscriberCount(1);
+
final OOCJoin<P, MatrixIndexes> join = new OOCJoin<>((idx,
left, right) -> {
T leftObj = (T) leftCache.findCached(left);
T rightObj = (T) rightCache.findCached(right);
@@ -280,36 +324,40 @@ public abstract class OOCInstruction extends Instruction {
future.complete(null);
});
+ if (explicitLeftCaching)
+ leftCache.scheduleDeletion();
+ if (explicitRightCaching)
+ rightCache.scheduleDeletion();
+
return future;
}
protected <T> CompletableFuture<Void> submitOOCTasks(final
List<OOCStream<T>> queues, BiConsumer<Integer, T> consumer, Runnable finalizer)
{
+ return submitOOCTasks(queues, consumer, finalizer, null);
+ }
+
+ protected <T> CompletableFuture<Void> submitOOCTasks(final
List<OOCStream<T>> queues, BiConsumer<Integer, T> consumer, Runnable finalizer,
BiConsumer<Integer, T> onNotProcessed) {
List<CompletableFuture<Void>> futures = new
ArrayList<>(queues.size());
for (int i = 0; i < queues.size(); i++)
futures.add(new CompletableFuture<>());
- return submitOOCTasks(queues, consumer, finalizer, futures,
null);
+ return submitOOCTasks(queues, consumer, finalizer, futures,
null, onNotProcessed);
}
- protected <T> CompletableFuture<Void> submitOOCTasks(final
List<OOCStream<T>> queues, BiConsumer<Integer, T> consumer, Runnable finalizer,
List<CompletableFuture<Void>> futures, BiFunction<Integer, T, Boolean>
predicate) {
+ protected <T> CompletableFuture<Void> submitOOCTasks(final
List<OOCStream<T>> queues, BiConsumer<Integer, T> consumer, Runnable finalizer,
List<CompletableFuture<Void>> futures, BiFunction<Integer, T, Boolean>
predicate, BiConsumer<Integer, T> onNotProcessed) {
addInStream(queues.toArray(OOCStream[]::new));
ExecutorService pool = CommonThreadPool.get();
final List<AtomicInteger> activeTaskCtrs = new
ArrayList<>(queues.size());
- final List<AtomicBoolean> streamsClosed = new
ArrayList<>(queues.size());
- for (int i = 0; i < queues.size(); i++) {
- activeTaskCtrs.add(new AtomicInteger(0));
- streamsClosed.add(new AtomicBoolean(false));
- }
+ for (int i = 0; i < queues.size(); i++)
+ activeTaskCtrs.add(new AtomicInteger(1));
- final AtomicInteger globalTaskCtr = new AtomicInteger(0);
final CompletableFuture<Void> globalFuture =
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new));
if (_outQueues == null)
_outQueues = Collections.emptySet();
final Runnable oocFinalizer = oocTask(finalizer, null,
Stream.concat(_outQueues.stream(),
_inQueues.stream()).toArray(OOCStream[]::new));
- final Object globalLock = new Object();
int i = 0;
@SuppressWarnings("unused")
@@ -319,84 +367,67 @@ public abstract class OOCInstruction extends Instruction {
for (OOCStream<T> queue : queues) {
final int k = i;
final AtomicInteger localTaskCtr =
activeTaskCtrs.get(k);
- final AtomicBoolean localStreamClosed =
streamsClosed.get(k);
final CompletableFuture<Void> localFuture =
futures.get(k);
+ final AtomicBoolean closeRaceWatchdog = new
AtomicBoolean(false);
//System.out.println("Substream (k " + k + ", id " +
streamId + ", type '" + queue.getClass().getSimpleName() + "', stream_id " +
queue.hashCode() + ")");
- queue.setSubscriber(oocTask(() -> {
- final T item = queue.dequeue();
+ queue.setSubscriber(oocTask(callback -> {
+ final T item = callback.get();
- if (predicate != null && item != null &&
!predicate.apply(k, item)) // Can get closed due to cancellation
- return;
+ if(item == null) {
+
if(!closeRaceWatchdog.compareAndSet(false, true))
+ throw new
DMLRuntimeException("Race condition observed: NO_MORE_TASKS callback has been
triggered more than once");
- synchronized (globalLock) {
- if (localFuture.isDone())
- return;
+ if(localTaskCtr.decrementAndGet() == 0)
{
+ // Then we can run the
finalization procedure already
+ localFuture.complete(null);
+ }
+ return;
+ }
- globalTaskCtr.incrementAndGet();
+ if(predicate != null && !predicate.apply(k,
item)) { // Can get closed due to cancellation
+ if(onNotProcessed != null)
+ onNotProcessed.accept(k, item);
+ return;
}
- localTaskCtr.incrementAndGet();
+ if(localFuture.isDone()) {
+ if(onNotProcessed != null)
+ onNotProcessed.accept(k, item);
+ return;
+ }
+ else {
+ localTaskCtr.incrementAndGet();
+ }
pool.submit(oocTask(() -> {
- if(item != null) {
- //System.out.println("Accept" +
((IndexedMatrixValue)item).getIndexes() + " (k " + k + ", id " + streamId +
")");
- consumer.accept(k, item);
- }
- else {
- //System.out.println("Close
substream (k " + k + ", id " + streamId + ")");
- localStreamClosed.set(true);
- }
-
- boolean runFinalizer = false;
-
- synchronized (globalLock) {
- int localTasks =
localTaskCtr.decrementAndGet();
- boolean finalizeStream =
localTasks == 0 && localStreamClosed.get();
-
- int globalTasks =
globalTaskCtr.get() - 1;
-
- if (finalizeStream ||
(globalFuture.isDone() && localTasks == 0)) {
-
localFuture.complete(null);
+ // TODO For caching streams, we have no
guarantee that item is still in memory -> NullPointer possible
+ consumer.accept(k, item);
- if
(globalFuture.isDone() && globalTasks == 0)
- runFinalizer =
true;
- }
-
- globalTaskCtr.decrementAndGet();
- }
-
- if (runFinalizer)
- oocFinalizer.run();
+ if(localTaskCtr.decrementAndGet() == 0)
+ localFuture.complete(null);
}, localFuture,
Stream.concat(_outQueues.stream(),
_inQueues.stream()).toArray(OOCStream[]::new)));
+
+ if(closeRaceWatchdog.get()) // Sanity check
+ throw new DMLRuntimeException("Race
condition observed");
}, null, Stream.concat(_outQueues.stream(),
_inQueues.stream()).toArray(OOCStream[]::new)));
i++;
}
- pool.shutdown();
-
globalFuture.whenComplete((res, e) -> {
- if (globalFuture.isCancelled() ||
globalFuture.isCompletedExceptionally())
+ if (globalFuture.isCancelled() ||
globalFuture.isCompletedExceptionally()) {
futures.forEach(f -> {
- if (!f.isDone()) {
- if (globalFuture.isCancelled()
|| globalFuture.isCompletedExceptionally())
+ if(!f.isDone()) {
+ if(globalFuture.isCancelled()
|| globalFuture.isCompletedExceptionally())
f.cancel(true);
else
f.complete(null);
}
});
-
- boolean runFinalizer;
-
- synchronized (globalLock) {
- runFinalizer = globalTaskCtr.get() == 0;
}
- if (runFinalizer)
- oocFinalizer.run();
-
- //System.out.println("Shutdown (id " + streamId + ")");
+ oocFinalizer.run();
});
return globalFuture;
}
@@ -405,8 +436,8 @@ public abstract class OOCInstruction extends Instruction {
return submitOOCTasks(List.of(queue), (i, tmp) ->
consumer.accept(tmp), finalizer);
}
- protected <T> CompletableFuture<Void> submitOOCTasks(OOCStream<T>
queue, Consumer<T> consumer, Runnable finalizer, Function<T, Boolean>
predicate) {
- return submitOOCTasks(List.of(queue), (i, tmp) ->
consumer.accept(tmp), finalizer, List.of(new CompletableFuture<Void>()), (i,
tmp) -> predicate.apply(tmp));
+ protected <T> CompletableFuture<Void> submitOOCTasks(OOCStream<T>
queue, Consumer<T> consumer, Runnable finalizer, Function<T, Boolean>
predicate, BiConsumer<Integer, T> onNotProcessed) {
+ return submitOOCTasks(List.of(queue), (i, tmp) ->
consumer.accept(tmp), finalizer, List.of(new CompletableFuture<Void>()), (i,
tmp) -> predicate.apply(tmp), onNotProcessed);
}
protected CompletableFuture<Void> submitOOCTask(Runnable r,
OOCStream<?>... queues) {
@@ -450,6 +481,31 @@ public abstract class OOCInstruction extends Instruction {
};
}
+ private <T> Consumer<OOCStream.QueueCallback<T>>
oocTask(Consumer<OOCStream.QueueCallback<T>> c, CompletableFuture<Void> future,
OOCStream<?>... queues) {
+ return callback -> {
+ try {
+ c.accept(callback);
+ }
+ catch (Exception ex) {
+ DMLRuntimeException re = ex instanceof
DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
+
+ if (_failed) // Do avoid infinite cycles
+ throw re;
+
+ _failed = true;
+
+ for (OOCStream<?> q : queues)
+ q.propagateFailure(re);
+
+ if (future != null)
+ future.completeExceptionally(re);
+
+ // Rethrow to ensure proper future handling
+ throw re;
+ }
+ };
+ }
+
/**
* Tracks blocks and their counts to enable early emission
* once all blocks for a given index are processed.
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
index 1a12cb138b..f02c847e05 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
@@ -22,6 +22,8 @@ package org.apache.sysds.runtime.instructions.ooc;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import java.util.function.Consumer;
+
public interface OOCStream<T> extends OOCStreamable<T> {
void enqueue(T t);
@@ -36,4 +38,28 @@ public interface OOCStream<T> extends OOCStreamable<T> {
boolean hasStreamCache();
CachingStream getStreamCache();
+
+ /**
+ * Registers a new subscriber that consumes the stream.
+ * While there is no guarantee for any specific order, the closing item
LocalTaskQueue.NO_MORE_TASKS
+ * is guaranteed to be invoked after every other item has finished
processing. Thus, the NO_MORE_TASKS
+ * callback can be used to free dependent resources and close output
streams.
+ */
+ void setSubscriber(Consumer<QueueCallback<T>> subscriber);
+
+ class QueueCallback<T> {
+ private final T _result;
+ private final DMLRuntimeException _failure;
+
+ public QueueCallback(T result, DMLRuntimeException failure) {
+ _result = result;
+ _failure = failure;
+ }
+
+ public T get() {
+ if (_failure != null)
+ throw _failure;
+ return _result;
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
index bdc4086bdc..af2c0afa66 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
@@ -25,6 +25,4 @@ public interface OOCStreamable<T> {
OOCStream<T> getWriteStream();
boolean isProcessed();
-
- void setSubscriber(Runnable subscriber);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java
new file mode 100644
index 0000000000..f56b0c46da
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Watchdog to help debug OOC streams/tasks that never close.
+ */
+public final class OOCWatchdog {
+ public static final boolean WATCH = false;
+ private static final ConcurrentHashMap<String, Entry> OPEN = new
ConcurrentHashMap<>();
+ private static final ScheduledExecutorService EXEC =
+ Executors.newSingleThreadScheduledExecutor(r -> {
+ Thread t = new Thread(r, "TemporaryWatchdog");
+ t.setDaemon(true);
+ return t;
+ });
+
+ private static final long STALE_MS = TimeUnit.SECONDS.toMillis(10);
+ private static final long SCAN_INTERVAL_MS =
TimeUnit.SECONDS.toMillis(10);
+
+ static {
+ EXEC.scheduleAtFixedRate(OOCWatchdog::scan, SCAN_INTERVAL_MS,
SCAN_INTERVAL_MS, TimeUnit.MILLISECONDS);
+ }
+
+ private OOCWatchdog() {
+ // no-op
+ }
+
+ public static void registerOpen(String id, String desc, String context,
OOCStream<?> stream) {
+ OPEN.put(id, new Entry(desc, context,
System.currentTimeMillis(), stream));
+ }
+
+ public static void addEvent(String id, String eventMsg) {
+ Entry e = OPEN.get(id);
+ if (e != null)
+ e.events.add(eventMsg);
+ }
+
+ public static void registerClose(String id) {
+ OPEN.remove(id);
+ }
+
+ private static void scan() {
+ long now = System.currentTimeMillis();
+ for (Map.Entry<String, Entry> e : OPEN.entrySet()) {
+ if (now - e.getValue().openedAt >= STALE_MS) {
+ if (e.getValue().events.isEmpty())
+ continue; // Probably just a stream
that has no consumer (remains to be checked why this can happen)
+ System.err.println("[TemporaryWatchdog] Still
open after " + (now - e.getValue().openedAt) + "ms: "
+ + e.getKey() + " (" + e.getValue().desc
+ ")"
+ + (e.getValue().context != null ? "
ctx=" + e.getValue().context : ""));
+ }
+ }
+ }
+
+ private static class Entry {
+ final String desc;
+ final String context;
+ final long openedAt;
+ @SuppressWarnings("unused")
+ final OOCStream<?> stream;
+ ConcurrentLinkedQueue<String> events;
+
+ Entry(String desc, String context, long openedAt, OOCStream<?>
stream) {
+ this.desc = desc;
+ this.context = context;
+ this.openedAt = openedAt;
+ this.stream = stream;
+ this.events = new ConcurrentLinkedQueue<>();
+ }
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java
index e56d32e440..d70fc3ccb9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java
@@ -20,7 +20,6 @@
package org.apache.sysds.runtime.instructions.ooc;
import org.apache.commons.lang3.NotImplementedException;
-import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -43,7 +42,6 @@ import
org.apache.sysds.runtime.matrix.operators.SimpleOperator;
import java.util.LinkedHashMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
-import java.util.concurrent.atomic.AtomicBoolean;
public class ParameterizedBuiltinOOCInstruction extends
ComputationOOCInstruction {
@@ -110,29 +108,26 @@ public class ParameterizedBuiltinOOCInstruction extends
ComputationOOCInstructio
Data finalPattern = pattern;
- AtomicBoolean found = new AtomicBoolean(false);
+ addInStream(qIn);
+ addOutStream(); // This instruction has no output stream
- MutableObject<CompletableFuture<Void>> futureRef = new
MutableObject<>();
- CompletableFuture<Void> future = submitOOCTasks(qIn,
tmp -> {
- boolean contains =
((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue());
+ CompletableFuture<Boolean> future = new
CompletableFuture<>();
- if (contains) {
- found.set(true);
+ filterOOC(qIn, tmp -> {
+ boolean contains =
((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue());
- // Now we may complete the future
- if (futureRef.getValue() != null)
-
futureRef.getValue().complete(null);
- }
- }, () -> {});
- futureRef.setValue(future);
+ if (contains)
+ future.complete(true);
+ }, tmp -> !future.isDone(), // Don't start a separate
worker if result already known
+ () -> future.complete(false)); // Then the
pattern was not found
+ boolean ret;
try {
- futureRef.getValue().get();
+ ret = future.get();
} catch (InterruptedException | ExecutionException e) {
throw new DMLRuntimeException(e);
}
- boolean ret = found.get();
ec.setScalarOutput(output.getName(), new
BooleanObject(ret));
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java
index 6edc4ecf27..5b996da0db 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java
@@ -23,13 +23,22 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
+
public class PlaybackStream implements OOCStream<IndexedMatrixValue>,
OOCStreamable<IndexedMatrixValue> {
private final CachingStream _streamCache;
- private int _streamIdx;
+ private final AtomicInteger _streamIdx;
+ private final AtomicInteger _taskCtr;
+ private final AtomicBoolean _subscriberSet;
public PlaybackStream(CachingStream streamCache) {
this._streamCache = streamCache;
- this._streamIdx = 0;
+ this._streamIdx = new AtomicInteger(0);
+ this._taskCtr = new AtomicInteger(1);
+ this._subscriberSet = new AtomicBoolean(false);
+ streamCache.incrSubscriberCount(1);
}
@Override
@@ -44,15 +53,29 @@ public class PlaybackStream implements
OOCStream<IndexedMatrixValue>, OOCStreama
@Override
public LocalTaskQueue<IndexedMatrixValue> toLocalTaskQueue() {
- final SubscribableTaskQueue<IndexedMatrixValue> q = new
SubscribableTaskQueue<>();
- setSubscriber(() -> q.enqueue(dequeue()));
+ final LocalTaskQueue<IndexedMatrixValue> q = new
LocalTaskQueue<>();
+ setSubscriber(val -> {
+ if (val.get() == null) {
+ q.closeInput();
+ return;
+ }
+ try {
+ q.enqueueTask(val.get());
+ }
+ catch(InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ });
return q;
}
@Override
- public synchronized IndexedMatrixValue dequeue() {
+ public IndexedMatrixValue dequeue() {
+ if (_subscriberSet.get())
+ throw new IllegalStateException("Cannot dequeue from a
playback stream if a subscriber has been set");
+
try {
- return _streamCache.get(_streamIdx++);
+ return _streamCache.get(_streamIdx.getAndIncrement());
} catch (InterruptedException e) {
throw new DMLRuntimeException(e);
}
@@ -74,8 +97,35 @@ public class PlaybackStream implements
OOCStream<IndexedMatrixValue>, OOCStreama
}
@Override
- public void setSubscriber(Runnable subscriber) {
- _streamCache.setSubscriber(subscriber);
+ public void setSubscriber(Consumer<QueueCallback<IndexedMatrixValue>>
subscriber) {
+ if (!_subscriberSet.compareAndSet(false, true))
+ throw new IllegalArgumentException("Subscriber cannot
be set multiple times");
+
+ /**
+ * To guarantee that NO_MORE_TASKS is invoked after all
subscriber calls
+ * finished, we keep track of running tasks using a task
counter.
+ */
+ _streamCache.setSubscriber(() -> {
+ try {
+ _taskCtr.incrementAndGet();
+
+ IndexedMatrixValue val;
+
+ try {
+ val =
_streamCache.get(_streamIdx.getAndIncrement());
+ } catch (InterruptedException e) {
+ throw new DMLRuntimeException(e);
+ }
+
+ if (val != null)
+ subscriber.accept(new
QueueCallback<>(val, null));
+
+ if (_taskCtr.addAndGet(val == null ? -2 : -1)
== 0)
+ subscriber.accept(new
QueueCallback<>(null, null));
+ } catch (DMLRuntimeException e) {
+ subscriber.accept(new QueueCallback<>(null, e));
+ }
+ }, false);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
index f136ffc2bb..d88cd1c5af 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
@@ -22,80 +22,173 @@ package org.apache.sysds.runtime.instructions.ooc;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import java.util.LinkedList;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
+
public class SubscribableTaskQueue<T> extends LocalTaskQueue<T> implements
OOCStream<T> {
- private Runnable _subscriber;
- @Override
- public synchronized void enqueue(T t) {
- try {
- super.enqueueTask(t);
- }
- catch (InterruptedException e) {
- throw new DMLRuntimeException(e);
+ private final AtomicInteger _availableCtr = new AtomicInteger(1);
+ private final AtomicBoolean _closed = new AtomicBoolean(false);
+ private volatile Consumer<QueueCallback<T>> _subscriber = null;
+ private String _watchdogId;
+
+ public SubscribableTaskQueue() {
+ if (OOCWatchdog.WATCH) {
+ _watchdogId = "STQ-" + hashCode();
+ // Capture a short context to help identify origin
+ OOCWatchdog.registerOpen(_watchdogId,
"SubscribableTaskQueue@" + hashCode(), getCtxMsg(), this);
}
+ }
- if(_subscriber != null)
- _subscriber.run();
+ private String getCtxMsg() {
+ StackTraceElement[] st = new Exception().getStackTrace();
+ // Skip the first few frames (constructor,
createWritableStream, etc.)
+ StringBuilder sb = new StringBuilder();
+ int limit = Math.min(st.length, 7);
+ for(int i = 2; i < limit; i++) {
+
sb.append(st[i].getClassName()).append(".").append(st[i].getMethodName()).append(":")
+ .append(st[i].getLineNumber());
+ if(i < limit - 1)
+ sb.append(" <- ");
+ }
+ return sb.toString();
}
@Override
- public T dequeue() {
- try {
- return super.dequeueTask();
+ public void enqueue(T t) {
+ if (t == NO_MORE_TASKS)
+ throw new DMLRuntimeException("Cannot enqueue
NO_MORE_TASKS item");
+
+ int cnt = _availableCtr.incrementAndGet();
+
+ if (cnt <= 1) { // Then the queue was already closed and we
disallow further enqueues
+ _availableCtr.decrementAndGet(); // Undo increment
+ throw new DMLRuntimeException("Cannot enqueue into
closed SubscribableTaskQueue");
}
- catch (InterruptedException e) {
- throw new DMLRuntimeException(e);
+
+ Consumer<QueueCallback<T>> s = _subscriber;
+
+ if (s != null) {
+ s.accept(new QueueCallback<>(t, _failure));
+ onDeliveryFinished();
+ return;
}
+
+ synchronized (this) {
+ // Re-check that subscriber is really null to avoid
race conditions
+ if (_subscriber == null) {
+ try {
+ super.enqueueTask(t);
+ }
+ catch(InterruptedException e) {
+ throw new DMLRuntimeException(e);
+ }
+ return;
+ }
+ // Otherwise do not insert and re-schedule subscriber
invocation
+ s = _subscriber;
+ }
+
+ // Last case if due to race a subscriber has been set
+ s.accept(new QueueCallback<>(t, _failure));
+ onDeliveryFinished();
}
@Override
- public synchronized void closeInput() {
- super.closeInput();
-
- if(_subscriber != null) {
- _subscriber.run();
- _subscriber = null;
- }
+ public synchronized void enqueueTask(T t) {
+ enqueue(t);
}
@Override
- public LocalTaskQueue<T> toLocalTaskQueue() {
- return this;
+ public T dequeue() {
+ try {
+ if (OOCWatchdog.WATCH)
+ OOCWatchdog.addEvent(_watchdogId, "dequeue -- "
+ getCtxMsg());
+ T deq = super.dequeueTask();
+ if (deq != NO_MORE_TASKS)
+ onDeliveryFinished();
+ return deq;
+ }
+ catch(InterruptedException e) {
+ throw new DMLRuntimeException(e);
+ }
}
@Override
- public OOCStream<T> getReadStream() {
- return this;
+ public synchronized T dequeueTask() {
+ return dequeue();
}
@Override
- public OOCStream<T> getWriteStream() {
- return this;
+ public synchronized void closeInput() {
+ if (_closed.compareAndSet(false, true)) {
+ super.closeInput();
+ onDeliveryFinished();
+ } else {
+ throw new IllegalStateException("Multiple close input
calls");
+ }
}
@Override
- public void setSubscriber(Runnable subscriber) {
- int queueSize;
+ public void setSubscriber(Consumer<QueueCallback<T>> subscriber) {
+ if(subscriber == null)
+ throw new IllegalArgumentException("Cannot set
subscriber to null");
- synchronized (this) {
+ LinkedList<T> data;
+
+ synchronized(this) {
if(_subscriber != null)
throw new DMLRuntimeException("Cannot set
multiple subscribers");
-
_subscriber = subscriber;
- queueSize = _data.size();
- queueSize += _closedInput ? 1 : 0; // To trigger the
NO_MORE_TASK element
+ if(_failure != null)
+ throw _failure;
+ data = _data;
+ _data = new LinkedList<>();
+ }
+
+ for (T t : data) {
+ subscriber.accept(new QueueCallback<>(t, _failure));
+ onDeliveryFinished();
}
+ }
- for (int i = 0; i < queueSize; i++)
- subscriber.run();
+ @SuppressWarnings("unchecked")
+ private void onDeliveryFinished() {
+ int ctr = _availableCtr.decrementAndGet();
+
+ if (ctr == 0) {
+ Consumer<QueueCallback<T>> s = _subscriber;
+ if (s != null)
+ s.accept(new QueueCallback<>((T)
LocalTaskQueue.NO_MORE_TASKS, _failure));
+
+ if (OOCWatchdog.WATCH)
+ OOCWatchdog.registerClose(_watchdogId);
+ }
}
@Override
public synchronized void propagateFailure(DMLRuntimeException re) {
super.propagateFailure(re);
+ Consumer<QueueCallback<T>> s = _subscriber;
+ if(s != null)
+ s.accept(new QueueCallback<>(null, re));
+ }
+
+ @Override
+ public LocalTaskQueue<T> toLocalTaskQueue() {
+ return this;
+ }
+
+ @Override
+ public OOCStream<T> getReadStream() {
+ return this;
+ }
- if(_subscriber != null)
- _subscriber.run();
+ @Override
+ public OOCStream<T> getWriteStream() {
+ return this;
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
index fd80b4e6e9..aba36297e7 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
@@ -25,8 +25,37 @@ import
org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import java.util.concurrent.ConcurrentHashMap;
+
public class TeeOOCInstruction extends ComputationOOCInstruction {
+ private static final ConcurrentHashMap<CachingStream, Integer> refCtr =
new ConcurrentHashMap<>();
+
+ public static void reset() {
+ if (!refCtr.isEmpty()) {
+ System.err.println("There are some dangling streams
still in the cache: " + refCtr);
+ refCtr.clear();
+ }
+ }
+
+ /**
+ * Increments the reference counter of a stream by the set amount.
+ */
+ public static void incrRef(OOCStreamable<IndexedMatrixValue> stream,
int incr) {
+ if (!(stream instanceof CachingStream))
+ return;
+
+ Integer ref = refCtr.compute((CachingStream)stream, (k, v) -> {
+ if (v == null)
+ v = 0;
+ v += incr;
+ return v <= 0 ? null : v;
+ });
+
+ if (ref == null)
+ ((CachingStream)stream).scheduleDeletion();
+ }
+
protected TeeOOCInstruction(OOCType type, CPOperand in1, CPOperand out,
String opcode, String istr) {
super(type, null, in1, out, opcode, istr);
}
@@ -45,9 +74,20 @@ public class TeeOOCInstruction extends
ComputationOOCInstruction {
MatrixObject min = ec.getMatrixObject(input1);
OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
+ CachingStream handle = qIn.hasStreamCache() ?
qIn.getStreamCache() : new CachingStream(qIn);
+
+ if (!qIn.hasStreamCache()) {
+ // We also set the input stream handle
+ min.setStreamHandle(handle);
+ incrRef(handle, 2);
+ }
+ else {
+ incrRef(handle, 1);
+ }
+
//get output and create new resettable stream
MatrixObject mo = ec.getMatrixObject(output);
- mo.setStreamHandle(new CachingStream(qIn));
+ mo.setStreamHandle(handle);
mo.setMetaData(min.getMetaData());
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
new file mode 100644
index 0000000000..e20b7ec426
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public class PCATest extends AutomatedTestBase {
+ private final static String TEST_NAME1 = "PCA";
+ private final static String TEST_DIR = "functions/ooc/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
PCATest.class.getSimpleName() + "/";
+ //private final static double eps = 1e-8;
+ private static final String INPUT_NAME_1 = "X";
+ private static final String OUTPUT_NAME_1 = "PC";
+ private static final String OUTPUT_NAME_2 = "V";
+
+ private final static int rows = 50000;
+ private final static int cols = 1000;
+ private final static int maxVal = 2;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ TestConfiguration config = new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+ addTestConfiguration(TEST_NAME1, config);
+ }
+
+ @Test
+ public void testPCA() {
+ boolean allow_opfusion = OptimizerUtils.ALLOW_OPERATOR_FUSION;
+ OptimizerUtils.ALLOW_OPERATOR_FUSION = false; // some fused ops
are not implemented yet
+ runPCATest(16);
+ OptimizerUtils.ALLOW_OPERATOR_FUSION = allow_opfusion;
+ }
+
+ private void runPCATest(int k) {
+ Types.ExecMode platformOld =
setExecMode(Types.ExecMode.SINGLE_NODE);
+
+ try {
+ getAndLoadTestConfiguration(TEST_NAME1);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+ programArgs = new String[] {"-explain", "hops",
"-stats", "-ooc", "-args", input(INPUT_NAME_1), Integer.toString(k),
output(OUTPUT_NAME_1), output(OUTPUT_NAME_2)};
+
+ // 1. Generate the data in-memory as MatrixBlock objects
+ double[][] X_data = getRandomMatrix(rows, cols, 0,
maxVal, 1, 7);
+
+ // 2. Convert the double arrays to MatrixBlock objects
+ MatrixBlock X_mb =
DataConverter.convertToMatrixBlock(X_data);
+
+ // 3. Create a binary matrix writer
+ MatrixWriter writer =
MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+
+ // 4. Write matrix A to a binary SequenceFile
+ writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1),
rows, cols, 1000, X_mb.getNonZeros());
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 +
".mtd"), Types.ValueType.FP64,
+ new MatrixCharacteristics(rows, cols, 1000,
X_mb.getNonZeros()), Types.FileFormat.BINARY);
+ X_data = null;
+ X_mb = null;
+
+ runTest(true, false, null, -1);
+
+ //check replace OOC op
+ //Assert.assertTrue("OOC wasn't used for replacement",
+ //
heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.REPLACE));
+
+ //compare results
+
+ // rerun without ooc flag
+ programArgs = new String[] {"-explain", "hops",
"-stats", "-args", input(INPUT_NAME_1), Integer.toString(k),
output(OUTPUT_NAME_1 + "_target"), output(OUTPUT_NAME_2 + "_target")};
+ runTest(true, false, null, -1);
+
+ // compare matrices
+ /*MatrixBlock ret1 =
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_1),
+ Types.FileFormat.BINARY, rows, cols, 1000);
+ MatrixBlock ret2 =
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_1 + "_target"),
+ Types.FileFormat.BINARY, rows, cols, 1000);
+ TestUtils.compareMatrices(ret1, ret2, eps);
+
+ MatrixBlock ret2_1 =
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_2),
+ Types.FileFormat.BINARY, rows, cols, 1000);
+ MatrixBlock ret2_2 =
DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_2 + "_target"),
+ Types.FileFormat.BINARY, rows, cols, 1000);
+ TestUtils.compareMatrices(ret2_1, ret2_2, eps);*/
+ }
+ catch(IOException e) {
+ throw new RuntimeException(e);
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
index 4365dd629f..6efa72d6ae 100644
--- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
+++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
@@ -304,6 +304,8 @@ public class StartupTest {
script.startWritingColToPipe(0, null, 0);
}
private static class ExitCalled extends RuntimeException implements
PythonDMLScript.ExitHandler {
+ private static final long serialVersionUID =
-4247240099965056602L;
+
@Override
public void exit(int status) {
throw this;
diff --git a/src/test/scripts/functions/ooc/PCA.dml
b/src/test/scripts/functions/ooc/PCA.dml
new file mode 100644
index 0000000000..567d701ec0
--- /dev/null
+++ b/src/test/scripts/functions/ooc/PCA.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1);
+k = $2;
+
+[PC, V] = pca(X=X, K=k)
+
+write(PC, $3, format="binary");
+write(V, $4, format="binary");