This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c49e25dfe9 [SYSTEMDS-3891] Improved Stream Handling and PCA support
c49e25dfe9 is described below

commit c49e25dfe952f1440fb0de57a382ad0e1fbf12d5
Author: Jannik Lindemann <[email protected]>
AuthorDate: Mon Dec 29 10:11:49 2025 +0100

    [SYSTEMDS-3891] Improved Stream Handling and PCA support
    
    Closes #2368.
---
 src/main/java/org/apache/sysds/api/DMLScript.java  |   3 +
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |   3 +-
 .../sysds/hops/rewrite/RewriteInjectOOCTee.java    | 212 +++++++++++++++------
 .../controlprogram/caching/CacheableData.java      |  15 +-
 .../controlprogram/parfor/LocalTaskQueue.java      |   2 +-
 .../instructions/cp/VariableCPInstruction.java     |   6 +
 .../runtime/instructions/ooc/CachingStream.java    | 161 +++++++++++++---
 .../instructions/ooc/IndexingOOCInstruction.java   |  28 +++
 .../ooc/MatrixIndexingOOCInstruction.java          |  83 ++++----
 .../instructions/ooc/OOCEvictionManager.java       |  36 +++-
 .../runtime/instructions/ooc/OOCInstruction.java   | 202 +++++++++++++-------
 .../sysds/runtime/instructions/ooc/OOCStream.java  |  26 +++
 .../runtime/instructions/ooc/OOCStreamable.java    |   2 -
 .../runtime/instructions/ooc/OOCWatchdog.java      |  96 ++++++++++
 .../ooc/ParameterizedBuiltinOOCInstruction.java    |  27 ++-
 .../runtime/instructions/ooc/PlaybackStream.java   |  66 ++++++-
 .../instructions/ooc/SubscribableTaskQueue.java    | 169 ++++++++++++----
 .../instructions/ooc/TeeOOCInstruction.java        |  42 +++-
 .../apache/sysds/test/functions/ooc/PCATest.java   | 123 ++++++++++++
 .../sysds/test/usertest/pythonapi/StartupTest.java |   2 +
 src/test/scripts/functions/ooc/PCA.dml             |  28 +++
 21 files changed, 1060 insertions(+), 272 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java 
b/src/main/java/org/apache/sysds/api/DMLScript.java
index e6becd83d1..91fb0ba3d1 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -71,6 +71,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.monitoring.FederatedMon
 import 
org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysds.runtime.instructions.ooc.OOCEvictionManager;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy;
@@ -497,6 +498,8 @@ public class DMLScript
                        ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, 
ConfigurationManager.getDMLConfig(), STATISTICS ? STATISTICS_COUNT : 0, null);
                }
                finally {
+                       //cleanup OOC streams and cache
+                       OOCEvictionManager.reset();
                        //cleanup scratch_space and all working dirs
                        
cleanupHadoopExecution(ConfigurationManager.getDMLConfig());
                        FederatedData.clearWorkGroup();
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index c2602dba51..026c242a80 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -77,7 +77,6 @@ public class ProgramRewriter{
                        //add static HOP DAG rewrite rules
                        _dagRuleSet.add(     new RewriteRemoveReadAfterWrite()  
             ); //dependency: before blocksize
                        _dagRuleSet.add(     new RewriteBlockSizeAndReblock()   
             );
-                       _dagRuleSet.add(     new RewriteInjectOOCTee()          
             );
                        if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
                                _dagRuleSet.add( new 
RewriteRemoveUnnecessaryCasts()             );
                        if( 
OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
@@ -94,6 +93,7 @@ public class ProgramRewriter{
                        if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE )
                                _dagRuleSet.add( new 
RewriteQuantizationFusedCompression()       );
 
+
                        //add statement block rewrite rules
                        if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )
                                _sbRuleSet.add(  new 
RewriteRemoveUnnecessaryBranches()          ); //dependency: constant folding
@@ -152,6 +152,7 @@ public class ProgramRewriter{
                        _dagRuleSet.add( new RewriteConstantFolding()           
         ); //dependency: cse
                _sbRuleSet.add(  new RewriteRemoveEmptyBasicBlocks()            
     );
                _sbRuleSet.add(  new RewriteRemoveEmptyForLoops()               
     );
+               _sbRuleSet.add(      new RewriteInjectOOCTee()                  
     );
        }
        
        /**
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
index 54dffa263e..7abfb15d1d 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
@@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.OpOpData;
 import org.apache.sysds.hops.DataOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.parser.StatementBlock;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -49,73 +50,20 @@ import java.util.Set;
  * 2. <b>Apply Rewrites (Modification):</b> Iterate over the collected 
candidate and put
  * {@code TeeOp}, and safely rewire the graph.
  */
-public class RewriteInjectOOCTee extends HopRewriteRule {
+public class RewriteInjectOOCTee extends StatementBlockRewriteRule {
 
        public static boolean APPLY_ONLY_XtX_PATTERN = false;
+
+       private static final Map<String, Integer> _transientVars = new 
HashMap<>();
+       private static final Map<String, List<Hop>> _transientHops = new 
HashMap<>();
+       private static final Set<String> teeTransientVars = new HashSet<>();
        
        private static final Set<Long> rewrittenHops = new HashSet<>();
        private static final Map<Long, Hop> handledHop = new HashMap<>();
 
        // Maintain a list of candidates to rewrite in the second pass
        private final List<Hop> rewriteCandidates = new ArrayList<>();
-
-       /**
-        * Handle a generic (last-level) hop DAG with multiple roots.
-        *
-        * @param roots high-level operator roots
-        * @param state program rewrite status
-        * @return list of high-level operators
-        */
-       @Override
-       public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) {
-               if (roots == null) {
-                       return null;
-               }
-
-               // Clear candidates for this pass
-               rewriteCandidates.clear();
-
-               // PASS 1: Identify candidates without modifying the graph
-               for (Hop root : roots) {
-                       root.resetVisitStatus();
-                       findRewriteCandidates(root);
-               }
-
-               // PASS 2: Apply rewrites to identified candidates
-               for (Hop candidate : rewriteCandidates) {
-                       applyTopDownTeeRewrite(candidate);
-               }
-
-               return roots;
-       }
-
-       /**
-        * Handle a predicate hop DAG with exactly one root.
-        *
-        * @param root  high-level operator root
-        * @param state program rewrite status
-        * @return high-level operator
-        */
-       @Override
-       public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
-               if (root == null) {
-                       return null;
-               }
-
-               // Clear candidates for this pass
-               rewriteCandidates.clear();
-
-               // PASS 1: Identify candidates without modifying the graph
-               root.resetVisitStatus();
-               findRewriteCandidates(root);
-
-               // PASS 2: Apply rewrites to identified candidates
-               for (Hop candidate : rewriteCandidates) {
-                       applyTopDownTeeRewrite(candidate);
-               }
-
-               return root;
-       }
+       private boolean forceTee = false;
 
        /**
         * First pass: Find candidates for rewrite without modifying the graph.
@@ -137,6 +85,35 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
                        findRewriteCandidates(input);
                }
 
+               boolean isRewriteCandidate = DMLScript.USE_OOC
+                       && hop.getDataType().isMatrix()
+                       && !HopRewriteUtils.isData(hop, OpOpData.TEE)
+                       && hop.getParent().size() > 1
+                       && (!APPLY_ONLY_XtX_PATTERN || 
isSelfTranposePattern(hop));
+
+               if (HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) && 
hop.getDataType().isMatrix()) {
+                       _transientVars.compute(hop.getName(), (key, ctr) -> {
+                               int incr = (isRewriteCandidate || forceTee) ? 2 
: 1;
+
+                               int ret = ctr == null ? 0 : ctr;
+                               ret += incr;
+
+                               if (ret > 1)
+                                       teeTransientVars.add(hop.getName());
+
+                               return ret;
+                       });
+
+                       _transientHops.compute(hop.getName(), (key, hops) -> {
+                               if (hops == null)
+                                       return new ArrayList<>(List.of(hop));
+                               hops.add(hop);
+                               return hops;
+                       });
+
+                       return; // We do not tee transient reads but rather 
inject before TWrite or PRead as caching stream
+               }
+
                // Check if this hop is a candidate for OOC Tee injection
                if (DMLScript.USE_OOC 
                        && hop.getDataType().isMatrix()
@@ -160,11 +137,17 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
                        return;
                }
 
+               int consumerCount = sharedInput.getParent().size();
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug("Inject tee for hop " + 
sharedInput.getHopID() + " ("
+                               + sharedInput.getName() + "), consumers=" + 
consumerCount);
+               }
+
                // Take a defensive copy of consumers before modifying the graph
                ArrayList<Hop> consumers = new 
ArrayList<>(sharedInput.getParent());
 
                // Create the new TeeOp with the original hop as input
-               DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(), 
+               DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
                        sharedInput.getDataType(), sharedInput.getValueType(), 
Types.OpOpData.TEE, null,
                        sharedInput.getDim1(), sharedInput.getDim2(), 
sharedInput.getNnz(), sharedInput.getBlocksize());
                HopRewriteUtils.addChildReference(teeOp, sharedInput);
@@ -177,6 +160,11 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
                // Record that we've handled this hop
                handledHop.put(sharedInput.getHopID(), teeOp);
                rewrittenHops.add(sharedInput.getHopID());
+
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug("Created tee hop " + teeOp.getHopID() + " -> "
+                               + teeOp.getName());
+               }
        }
 
        @SuppressWarnings("unused")
@@ -196,4 +184,108 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
                }
                return hasTransposeConsumer &&  hasMatrixMultiplyConsumer;
        }
+
+       @Override
+       public boolean createsSplitDag() {
+               return false;
+       }
+
+       @Override
+       public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
ProgramRewriteStatus state) {
+               if (!DMLScript.USE_OOC)
+                       return List.of(sb);
+
+               rewriteSB(sb, state);
+
+               for (String tVar : teeTransientVars) {
+                       List<Hop> tHops = _transientHops.get(tVar);
+
+                       if (tHops == null)
+                               continue;
+
+                       for (Hop affectedHops : tHops) {
+                               applyTopDownTeeRewrite(affectedHops);
+                       }
+
+                       tHops.clear();
+               }
+
+               removeRedundantTeeChains(sb);
+
+               return List.of(sb);
+       }
+
+       @Override
+       public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> 
sbs, ProgramRewriteStatus state) {
+               if (!DMLScript.USE_OOC)
+                       return sbs;
+
+               for (StatementBlock sb : sbs)
+                       rewriteSB(sb, state);
+
+               for (String tVar : teeTransientVars) {
+                       List<Hop> tHops = _transientHops.get(tVar);
+
+                       if (tHops == null)
+                               continue;
+
+                       for (Hop affectedHops : tHops) {
+                               applyTopDownTeeRewrite(affectedHops);
+                       }
+               }
+
+               for (StatementBlock sb : sbs)
+                       removeRedundantTeeChains(sb);
+
+               return sbs;
+       }
+
+       private void rewriteSB(StatementBlock sb, ProgramRewriteStatus state) {
+               rewriteCandidates.clear();
+
+               if (sb.getHops() != null) {
+                       for(Hop hop : sb.getHops()) {
+                               hop.resetVisitStatus();
+                               findRewriteCandidates(hop);
+                       }
+               }
+
+               for (Hop candidate : rewriteCandidates) {
+                       applyTopDownTeeRewrite(candidate);
+               }
+       }
+
+       private void removeRedundantTeeChains(StatementBlock sb) {
+               if (sb == null || sb.getHops() == null)
+                       return;
+
+               Hop.resetVisitStatus(sb.getHops());
+               for (Hop hop : sb.getHops())
+                       removeRedundantTeeChains(hop);
+               Hop.resetVisitStatus(sb.getHops());
+       }
+
+       private void removeRedundantTeeChains(Hop hop) {
+               if (hop.isVisited())
+                       return;
+
+               ArrayList<Hop> inputs = new ArrayList<>(hop.getInput());
+               for (Hop in : inputs)
+                       removeRedundantTeeChains(in);
+
+               if (HopRewriteUtils.isData(hop, OpOpData.TEE) && 
hop.getInput().size() == 1) {
+                       Hop teeInput = hop.getInput().get(0);
+                       if (HopRewriteUtils.isData(teeInput, OpOpData.TEE)) {
+                               if (LOG.isDebugEnabled()) {
+                                       LOG.debug("Remove redundant tee hop " + 
hop.getHopID()
+                                               + " (" + hop.getName() + ") -> 
" + teeInput.getHopID()
+                                               + " (" + teeInput.getName() + 
")");
+                               }
+                               
HopRewriteUtils.rewireAllParentChildReferences(hop, teeInput);
+                               HopRewriteUtils.removeAllChildReferences(hop);
+                       }
+               }
+
+               hop.setVisited();
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 34a8aa1863..d826af89c0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -471,12 +471,12 @@ public abstract class CacheableData<T extends 
CacheBlock<?>> extends Data
                return  _bcHandle != null && _bcHandle.hasBackReference();
        }
        
-       public OOCStream<IndexedMatrixValue> getStreamHandle() {
+       public synchronized OOCStream<IndexedMatrixValue> getStreamHandle() {
                if( !hasStreamHandle() ) {
                        final SubscribableTaskQueue<IndexedMatrixValue> 
_mStream = new SubscribableTaskQueue<>();
-                       _streamHandle = _mStream;
                        DataCharacteristics dc = getDataCharacteristics();
                        MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
+                       _streamHandle = _mStream;
                        LongStream.range(0, dc.getNumBlocks())
                                .mapToObj(i -> 
UtilFunctions.createIndexedMatrixBlock(src, dc, i))
                                .forEach( blk -> {
@@ -489,7 +489,14 @@ public abstract class CacheableData<T extends 
CacheBlock<?>> extends Data
                        _mStream.closeInput();
                }
                
-               return _streamHandle.getReadStream();
+               OOCStream<IndexedMatrixValue> stream = 
_streamHandle.getReadStream();
+               if (!stream.hasStreamCache())
+                       _streamHandle = null; // To ensure read once
+               return stream;
+       }
+
+       public OOCStreamable<IndexedMatrixValue> getStreamable() {
+               return _streamHandle;
        }
        
        /**
@@ -499,7 +506,7 @@ public abstract class CacheableData<T extends 
CacheBlock<?>> extends Data
         * @return true if existing, false otherwise
         */
        public boolean hasStreamHandle() {
-               return _streamHandle != null && !_streamHandle.isProcessed();
+               return _streamHandle != null;
        } 
 
        @SuppressWarnings({ "rawtypes", "unchecked" })
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
index 783981e0f1..50143cd0ad 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
@@ -45,7 +45,7 @@ public class LocalTaskQueue<T>
        
        protected LinkedList<T>  _data        = null;
        protected boolean          _closedInput = false;
-       private DMLRuntimeException _failure = null;
+       protected DMLRuntimeException _failure = null;
        private static final Log LOG = 
LogFactory.getLog(LocalTaskQueue.class.getName());
        
        public LocalTaskQueue()
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 5dd8e55e82..afc446f747 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -46,6 +46,7 @@ import org.apache.sysds.runtime.data.TensorBlock;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
 import org.apache.sysds.runtime.io.FileFormatProperties;
 import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import org.apache.sysds.runtime.io.FileFormatPropertiesHDF5;
@@ -1026,6 +1027,9 @@ public class VariableCPInstruction extends CPInstruction 
implements LineageTrace
                if ( dd == null )
                        throw new DMLRuntimeException("Unexpected error: could 
not find a data object for variable name:" + getInput1().getName() + ", while 
processing instruction " +this.toString());
 
+               if (DMLScript.USE_OOC && dd instanceof MatrixObject)
+                       
TeeOOCInstruction.incrRef(((MatrixObject)dd).getStreamable(), 1);
+
                // remove existing variable bound to target name
                Data input2_data = ec.removeVariable(getInput2().getName());
 
@@ -1117,6 +1121,8 @@ public class VariableCPInstruction extends CPInstruction 
implements LineageTrace
        public static void processRmvarInstruction( ExecutionContext ec, String 
varname ) {
                // remove variable from symbol table
                Data dat = ec.removeVariable(varname);
+               if (DMLScript.USE_OOC && dat instanceof MatrixObject)
+                       TeeOOCInstruction.incrRef(((MatrixObject) 
dat).getStreamable(), -1);
                //cleanup matrix data on fs/hdfs (if necessary)
                if( dat != null )
                        ec.cleanupDataObject(dat);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
index d7c80e4de3..cdc2391151 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
@@ -24,6 +24,7 @@ import 
org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList;
 
 import java.util.HashMap;
 import java.util.Map;
@@ -39,6 +40,7 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
 
        // original live stream
        private final OOCStream<IndexedMatrixValue> _source;
+       private final IntArrayList _consumptionCounts = new IntArrayList();
 
        // stream identifier
        private final long _streamId;
@@ -54,6 +56,10 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
 
        private DMLRuntimeException _failure;
 
+       private boolean deletable = false;
+       private int maxConsumptionCount = 0;
+       private int cachePins = 0;
+
        public CachingStream(OOCStream<IndexedMatrixValue> source) {
                this(source, _streamSeq.getNextID());
        }
@@ -61,23 +67,43 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
        public CachingStream(OOCStream<IndexedMatrixValue> source, long 
streamId) {
                _source = source;
                _streamId = streamId;
-               source.setSubscriber(() -> {
+               source.setSubscriber(tmp -> {
                        try {
-                               boolean closed = fetchFromStream();
-                               Runnable[] mSubscribers = _subscribers;
+                               final IndexedMatrixValue task = tmp.get();
+                               int blk;
+                               Runnable[] mSubscribers;
+
+                               synchronized (this) {
+                                       if(task != 
LocalTaskQueue.NO_MORE_TASKS) {
+                                               if (!_cacheInProgress)
+                                                       throw new 
DMLRuntimeException("Stream is closed");
+                                               
OOCEvictionManager.put(_streamId, _numBlocks, task);
+                                               if (_index != null)
+                                                       
_index.put(task.getIndexes(), _numBlocks);
+                                               blk = _numBlocks;
+                                               _numBlocks++;
+                                               _consumptionCounts.add(0);
+                                               notifyAll();
+                                       }
+                                       else {
+                                               _cacheInProgress = false; // 
caching is complete
+                                               notifyAll();
+                                               blk = -1;
+                                       }
+
+                                       mSubscribers = _subscribers;
+                               }
 
                                if(mSubscribers != null) {
                                        for(Runnable mSubscriber : mSubscribers)
                                                mSubscriber.run();
 
-                                       if (closed) {
+                                       if (blk == -1) {
                                                synchronized (this) {
                                                        _subscribers = null;
                                                }
                                        }
                                }
-                       } catch (InterruptedException e) {
-                               throw new DMLRuntimeException(e);
                        } catch (DMLRuntimeException e) {
                                // Propagate failure to subscribers
                                _failure = e;
@@ -98,25 +124,28 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
                });
        }
 
-       private synchronized boolean fetchFromStream() throws 
InterruptedException {
-               if(!_cacheInProgress)
-                       throw new DMLRuntimeException("Stream is closed");
+       public synchronized void scheduleDeletion() {
+               deletable = true;
+               if (_cacheInProgress && maxConsumptionCount == 0)
+                       throw new DMLRuntimeException("Cannot have a caching 
stream with no listeners");
+               for (int i = 0; i < _consumptionCounts.size(); i++) {
+                       tryDeleteBlock(i);
+               }
+       }
 
-               IndexedMatrixValue task = _source.dequeue();
+       public String toString() {
+               return "CachingStream@" + _streamId;
+       }
 
-               if(task != LocalTaskQueue.NO_MORE_TASKS) {
-                       OOCEvictionManager.put(_streamId, _numBlocks, task);
-                       if (_index != null)
-                               _index.put(task.getIndexes(), _numBlocks);
-                       _numBlocks++;
-                       notifyAll();
-                       return false;
-               }
-               else {
-                       _cacheInProgress = false; // caching is complete
-                       notifyAll();
-                       return true;
-               }
+       private synchronized void tryDeleteBlock(int i) {
+               if (cachePins > 0)
+                       return; // Block deletion is prevented
+
+               int count = _consumptionCounts.getInt(i);
+               if (count > maxConsumptionCount)
+                       throw new DMLRuntimeException("Cannot have more than " 
+ maxConsumptionCount + " consumptions.");
+               if (count == maxConsumptionCount)
+                       OOCEvictionManager.forget(_streamId, i);
        }
 
        public synchronized IndexedMatrixValue get(int idx) throws 
InterruptedException {
@@ -129,6 +158,16 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
                                if (_index != null) // Ensure index is up to 
date
                                        _index.putIfAbsent(out.getIndexes(), 
idx);
 
+                               int newCount = _consumptionCounts.getInt(idx)+1;
+
+                               if (newCount > maxConsumptionCount)
+                                       throw new DMLRuntimeException("Consumer 
overflow! Expected: " + maxConsumptionCount);
+
+                               _consumptionCounts.set(idx, newCount);
+
+                               if (deletable)
+                                       tryDeleteBlock(idx);
+
                                return out;
                        } else if (!_cacheInProgress)
                                return 
(IndexedMatrixValue)LocalTaskQueue.NO_MORE_TASKS;
@@ -137,8 +176,31 @@ public class CachingStream implements 
OOCStreamable<IndexedMatrixValue> {
                }
        }
 
+       public synchronized int findCachedIndex(MatrixIndexes idx) {
+               return _index.get(idx);
+       }
+
        public synchronized IndexedMatrixValue findCached(MatrixIndexes idx) {
-               return OOCEvictionManager.get(_streamId, _index.get(idx));
+               int mIdx = _index.get(idx);
+               int newCount = _consumptionCounts.getInt(mIdx)+1;
+               if (newCount > maxConsumptionCount)
+                       throw new DMLRuntimeException("Consumer overflow in " + 
_streamId + "_" + mIdx + ". Expected: " + maxConsumptionCount);
+               _consumptionCounts.set(mIdx, newCount);
+
+               IndexedMatrixValue imv = OOCEvictionManager.get(_streamId, 
mIdx);
+
+               if (deletable)
+                       tryDeleteBlock(mIdx);
+
+               return imv;
+       }
+
+       /**
+        * Finds a cached item without counting it as a consumption.
+        */
+       public synchronized IndexedMatrixValue peekCached(MatrixIndexes idx) {
+               int mIdx = _index.get(idx);
+               return OOCEvictionManager.get(_streamId, mIdx);
        }
 
        public synchronized void activateIndexing() {
@@ -161,12 +223,17 @@ public class CachingStream implements OOCStreamable<IndexedMatrixValue> {
                return false;
        }
 
-       @Override
-       public void setSubscriber(Runnable subscriber) {
+       public void setSubscriber(Runnable subscriber, boolean incrConsumers) {
                int mNumBlocks;
+               boolean cacheInProgress;
                synchronized (this) {
+                       if (deletable) // checked under the monitor, atomically with the consumer registration below
+                               throw new DMLRuntimeException("Cannot register a new subscriber on " + this + " because has been flagged for deletion");
                        mNumBlocks = _numBlocks;
-                       if (_cacheInProgress) {
+                       cacheInProgress = _cacheInProgress;
+                       if (incrConsumers)
+                               maxConsumptionCount++;
+                       if (cacheInProgress) {
                                int newLen = _subscribers == null ? 1 : _subscribers.length + 1;
                                Runnable[] newSubscribers = new Runnable[newLen];
 
@@ -181,7 +248,47 @@ public class CachingStream implements OOCStreamable<IndexedMatrixValue> {
                for (int i = 0; i < mNumBlocks; i++)
                        subscriber.run();
 
-               if (!_cacheInProgress)
+               if (!cacheInProgress)
                        subscriber.run(); // To fetch the NO_MORE_TASK element
        }
+
+       /**
+        * Artificially increase subscriber count.
+        * Only use if certain blocks are accessed more than once.
+        */
+       public synchronized void incrSubscriberCount(int count) {
+               maxConsumptionCount += count;
+       }
+
+       /**
+        * Artificially increase the processing count of a block.
+        */
+       public synchronized void incrProcessingCount(int i, int count) {
+               _consumptionCounts.set(i, _consumptionCounts.getInt(i)+count);
+
+               if (deletable)
+                       tryDeleteBlock(i);
+       }
+
+       /**
+        * Force pins blocks in the cache to not be subject to block deletion.
+        */
+       public synchronized void pinStream() {
+               cachePins++;
+       }
+
+       /**
+        * Unpins the stream, allowing blocks to be deleted from cache.
+        * Throws a DMLRuntimeException on an unbalanced unpin (no matching pinStream()).
+        */
+       public synchronized void unpinStream() {
+               if (cachePins <= 0)
+                       throw new DMLRuntimeException("Cannot unpin " + this + " because it is not pinned");
+               cachePins--;
+
+               if (cachePins == 0) {
+                       for (int i = 0; i < _consumptionCounts.size(); i++)
+                               tryDeleteBlock(i);
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java
index 1d555da8d6..175d81d6e0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/IndexingOOCInstruction.java
@@ -115,6 +115,34 @@ public abstract class IndexingOOCInstruction extends 
UnaryOOCInstruction {
                        return (_indexRange.rowStart % _blocksize) == 0 && 
(_indexRange.colStart % _blocksize) == 0;
                }
 
+               public int getNumConsumptions(MatrixIndexes index) {
+                       long blockRow = index.getRowIndex() - 1;
+                       long blockCol = index.getColumnIndex() - 1;
+
+                       if(!_blockRange.isWithin(blockRow, blockCol))
+                               return 0;
+
+                       long blockRowStart = blockRow * _blocksize;
+                       long blockRowEnd = blockRowStart + _blocksize - 1;
+                       long blockColStart = blockCol * _blocksize;
+                       long blockColEnd = blockColStart + _blocksize - 1;
+
+                       long overlapRowStart = Math.max(_indexRange.rowStart, 
blockRowStart);
+                       long overlapRowEnd = Math.min(_indexRange.rowEnd, 
blockRowEnd);
+                       long overlapColStart = Math.max(_indexRange.colStart, 
blockColStart);
+                       long overlapColEnd = Math.min(_indexRange.colEnd, 
blockColEnd);
+
+                       if(overlapRowStart > overlapRowEnd || overlapColStart > 
overlapColEnd)
+                               return 0;
+
+                       int outRowStart = (int) ((overlapRowStart - 
_indexRange.rowStart) / _blocksize);
+                       int outRowEnd = (int) ((overlapRowEnd - 
_indexRange.rowStart) / _blocksize);
+                       int outColStart = (int) ((overlapColStart - 
_indexRange.colStart) / _blocksize);
+                       int outColEnd = (int) ((overlapColEnd - 
_indexRange.colStart) / _blocksize);
+
+                       return (outRowEnd - outRowStart + 1) * (outColEnd - 
outColStart + 1);
+               }
+
                public boolean putNext(MatrixIndexes index, T data, 
BiConsumer<MatrixIndexes, Sector<T>> emitter) {
                        long blockRow = index.getRowIndex() - 1;
                        long blockCol = index.getColumnIndex() - 1;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java
index a04a77677c..c7cd8c9d3f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixIndexingOOCInstruction.java
@@ -33,6 +33,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.util.IndexRange;
 
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -43,11 +44,6 @@ public class MatrixIndexingOOCInstruction extends 
IndexingOOCInstruction {
                super(in, rl, ru, cl, cu, out, opcode, istr);
        }
 
-       protected MatrixIndexingOOCInstruction(CPOperand lhsInput, CPOperand 
rhsInput, CPOperand rl, CPOperand ru,
-               CPOperand cl, CPOperand cu, CPOperand out, String opcode, 
String istr) {
-               super(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, istr);
-       }
-
        @Override
        public void processInstruction(ExecutionContext ec) {
                String opcode = getOpcode();
@@ -96,8 +92,9 @@ public class MatrixIndexingOOCInstruction extends 
IndexingOOCInstruction {
                                final int outBlockCols = (int) 
Math.ceil((double) (ix.colSpan() + 1) / blocksize);
                                final int totalBlocks = outBlockRows * 
outBlockCols;
                                final AtomicInteger producedBlocks = new 
AtomicInteger(0);
+                               CompletableFuture<Void> future = new  
CompletableFuture<>();
 
-                               CompletableFuture<Void> future = filterOOC(qIn, 
tmp -> {
+                               filterOOC(qIn, tmp -> {
                                        MatrixIndexes inIdx = tmp.getIndexes();
                                        long blockRow = inIdx.getRowIndex() - 1;
                                        long blockCol = inIdx.getColumnIndex() 
- 1;
@@ -124,12 +121,12 @@ public class MatrixIndexingOOCInstruction extends 
IndexingOOCInstruction {
                                        long outBlockCol = blockCol - 
firstBlockCol + 1;
                                        qOut.enqueue(new IndexedMatrixValue(new 
MatrixIndexes(outBlockRow, outBlockCol), outBlock));
 
-                                       if(producedBlocks.incrementAndGet() >= 
totalBlocks) {
-                                               CompletableFuture<Void> f = 
futureRef.get();
-                                               if(f != null)
-                                                       f.cancel(true);
-                                       }
+                                       if(producedBlocks.incrementAndGet() >= 
totalBlocks)
+                                               future.complete(null);
                                }, tmp -> {
+                                       if (future.isDone()) // Then we may 
skip blocks and avoid submitting tasks
+                                               return false;
+
                                        long blockRow = 
tmp.getIndexes().getRowIndex() - 1;
                                        long blockCol = 
tmp.getIndexes().getColumnIndex() - 1;
                                        return blockRow >= firstBlockRow && 
blockRow <= lastBlockRow && blockCol >= firstBlockCol &&
@@ -139,20 +136,23 @@ public class MatrixIndexingOOCInstruction extends IndexingOOCInstruction {
                                return;
                        }
 
-                       final BlockAligner<IndexedBlockMeta> aligner = new BlockAligner<>(ix, blocksize);
+                       final BlockAligner<MatrixIndexes> aligner = new BlockAligner<>(ix, blocksize);
+                       final ConcurrentHashMap<MatrixIndexes, Integer> consumptionCounts = new ConcurrentHashMap<>();
 
                        // We may need to construct our own intermediate stream to properly manage the cached items
                        boolean hasIntermediateStream = !qIn.hasStreamCache();
-                       final CachingStream cachedStream = hasIntermediateStream ? new CachingStream(new SubscribableTaskQueue<>()) : qOut.getStreamCache();
+                       final CachingStream cachedStream = hasIntermediateStream ? new CachingStream(new SubscribableTaskQueue<>()) : qIn.getStreamCache();
                        cachedStream.activateIndexing();
+                       cachedStream.incrSubscriberCount(1); // One extra consumption per input block; its up-to-4 sector reads are tracked in consumptionCounts
+                       final CompletableFuture<Void> future = new CompletableFuture<>();
 
-                       CompletableFuture<Void> future = filterOOC(qIn.getReadStream(), tmp -> {
+                       filterOOC(qIn.getReadStream(), tmp -> {
                                if (hasIntermediateStream) {
                                        // We write to an intermediate stream to ensure that these matrix blocks are properly cached
                                        cachedStream.getWriteStream().enqueue(tmp);
                                }
 
-                               boolean completed = aligner.putNext(tmp.getIndexes(), new IndexedBlockMeta(tmp), (idx, sector) -> {
+                               boolean completed = aligner.putNext(tmp.getIndexes(), tmp.getIndexes(), (idx, sector) -> {
                                        int targetBlockRow = (int) (idx.getRowIndex() - 1);
                                        int targetBlockCol = (int) (idx.getColumnIndex() - 1);
 
@@ -176,18 +176,18 @@ public class MatrixIndexingOOCInstruction extends 
IndexingOOCInstruction {
 
                                        for(int r = 0; r < rowSegments; r++) {
                                                for(int c = 0; c < colSegments; 
c++) {
-                                                       IndexedBlockMeta ibm = 
sector.get(r, c);
-                                                       if(ibm == null)
+                                                       MatrixIndexes mIdx = 
sector.get(r, c);
+                                                       if(mIdx == null)
                                                                continue;
 
-                                                       IndexedMatrixValue mv = 
cachedStream.findCached(ibm.idx);
+                                                       IndexedMatrixValue mv = 
cachedStream.peekCached(mIdx);
                                                        MatrixBlock srcBlock = 
(MatrixBlock) mv.getValue();
 
                                                        if(target == null)
                                                                target = new 
MatrixBlock(nRows, nCols, srcBlock.isInSparseFormat());
 
-                                                       long srcBlockRowStart = 
(ibm.idx.getRowIndex() - 1) * blocksize;
-                                                       long srcBlockColStart = 
(ibm.idx.getColumnIndex() - 1) * blocksize;
+                                                       long srcBlockRowStart = 
(mIdx.getRowIndex() - 1) * blocksize;
+                                                       long srcBlockColStart = 
(mIdx.getColumnIndex() - 1) * blocksize;
                                                        long 
sliceRowStartGlobal = Math.max(targetRowStartGlobal, srcBlockRowStart);
                                                        long sliceRowEndGlobal 
= Math.min(targetRowEndGlobal,
                                                                
srcBlockRowStart + srcBlock.getNumRows() - 1);
@@ -205,21 +205,31 @@ public class MatrixIndexingOOCInstruction extends 
IndexingOOCInstruction {
 
                                                        MatrixBlock sliced = 
srcBlock.slice(sliceRowStart, sliceRowEnd, sliceColStart, sliceColEnd);
                                                        sliced.putInto(target, 
targetRowOffset, targetColOffset, true);
+                                                       final int 
maxConsumptions = aligner.getNumConsumptions(mIdx);
+
+                                                       Integer con = 
consumptionCounts.compute(mIdx, (k, v) -> {
+                                                               if (v == null)
+                                                                       v = 0;
+                                                               v = v+1;
+                                                               if (v == 
maxConsumptions)
+                                                                       return 
null;
+                                                               return v;
+                                                       });
+
+                                                       if (con == null)
+                                                               
cachedStream.incrProcessingCount(cachedStream.findCachedIndex(mIdx), 1);
                                                }
                                        }
 
                                        qOut.enqueue(new 
IndexedMatrixValue(idx, target));
                                });
 
-                               if(completed) {
-                                       // All blocks have been processed; we 
can cancel the future
-                                       // Currently, this does not affect 
processing (predicates prevent task submission anyway).
-                                       // However, a cancelled future may 
allow early file read aborts once implemented.
-                                       CompletableFuture<Void> f = 
futureRef.get();
-                                       if(f != null)
-                                               f.cancel(true);
-                               }
+                               if(completed)
+                                       future.complete(null);
                        }, tmp -> {
+                               if (future.isDone()) // Then we may skip blocks 
and avoid submitting tasks
+                                       return false;
+
                                // Pre-filter incoming blocks to avoid 
unnecessary task submission
                                long blockRow = tmp.getIndexes().getRowIndex() 
- 1;
                                long blockCol = 
tmp.getIndexes().getColumnIndex() - 1;
@@ -228,8 +238,15 @@ public class MatrixIndexingOOCInstruction extends 
IndexingOOCInstruction {
                        }, () -> {
                                aligner.close();
                                qOut.closeInput();
+                       }, tmp -> {
+                               // If elements are not processed in an existing 
caching stream, we increment the process counter to allow block deletion
+                               if (!hasIntermediateStream)
+                                       
cachedStream.incrProcessingCount(cachedStream.findCachedIndex(tmp.getIndexes()),
 1);
                        });
                        futureRef.set(future);
+
+                       if (hasIntermediateStream)
+                               cachedStream.scheduleDeletion(); // We can 
immediately delete blocks after consumption
                }
                //left indexing
                else if(opcode.equalsIgnoreCase(Opcodes.LEFT_INDEX.toString())) 
{
@@ -239,16 +256,4 @@ public class MatrixIndexingOOCInstruction extends 
IndexingOOCInstruction {
                        throw new DMLRuntimeException(
                                "Invalid opcode (" + opcode + ") encountered in 
MatrixIndexingOOCInstruction.");
        }
-
-       private static class IndexedBlockMeta {
-               public final MatrixIndexes idx;
-               ////public final long nrows;
-               //public final long ncols;
-
-               public IndexedBlockMeta(IndexedMatrixValue mv) {
-                       this.idx = mv.getIndexes();
-                       //this.nrows = mv.getValue().getNumRows();
-                       //this.ncols = mv.getValue().getNumColumns();
-               }
-       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
index 099a26ebd9..235afce833 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCEvictionManager.java
@@ -106,7 +106,7 @@ import java.util.concurrent.locks.ReentrantLock;
 public class OOCEvictionManager {
 
        // Configuration: OOC buffer limit as percentage of heap
-       private static final double OOC_BUFFER_PERCENTAGE = 0.15 * 0.01 * 2; // 
15% of heap
+       private static final double OOC_BUFFER_PERCENTAGE = 0.15; // 15% of heap
 
        private static final double PARTITION_EVICTION_SIZE = 64 * 1024 * 1024; 
// 64 MB
 
@@ -192,6 +192,44 @@ public class OOCEvictionManager {
                LocalFileUtils.createLocalFileIfNotExist(_spillDir);
        }
 
+       /**
+        * Resets the eviction cache, tracked size and spill/partition bookkeeping, and calls TeeOOCInstruction.reset().
+        * Blocks still present in the cache are reported on stderr before being discarded.
+        */
+       public static void reset() {
+               TeeOOCInstruction.reset();
+               if (!_cache.isEmpty()) {
+                       System.err.println("There are dangling elements in the OOC Eviction cache: " + _cache.size());
+               }
+               _size.set(0);
+               _cache.clear();
+               _spillLocations.clear();
+               _partitions.clear();
+               _partitionCounter.set(0);
+               _streamPartitions.clear();
+       }
+
+       /**
+        * Removes a block from the cache without setting its data to null.
+        * If the removed entry was HOT, its size is subtracted from the tracked cache size.
+        */
+       public static void forget(long streamId, int blockId) {
+               BlockEntry e;
+               synchronized (_cacheLock) {
+                       e = _cache.remove(streamId + "_" + blockId);
+               }
+
+               if (e != null) {
+                       e.lock.lock();
+                       try {
+                               if (e.state == BlockState.HOT)
+                                       _size.addAndGet(-e.size);
+                       } finally {
+                               e.lock.unlock();
+                       }
+               }
+       }
+
        /**
         * Store a block in the OOC cache (serialize once)
         */
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index ca13cfdb2c..80cd6b6a87 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -131,11 +131,15 @@ public abstract class OOCInstruction extends Instruction {
                return new SubscribableTaskQueue<>();
        }
 
-       protected <T, R> CompletableFuture<Void> filterOOC(OOCStream<T> qIn, 
Consumer<T> processor, Function<T, Boolean> predicate, Runnable finalizer) {
+       protected <T> CompletableFuture<Void> filterOOC(OOCStream<T> qIn, 
Consumer<T> processor, Function<T, Boolean> predicate, Runnable finalizer) {
+               return filterOOC(qIn, processor, predicate, finalizer, null);
+       }
+
+       protected <T> CompletableFuture<Void> filterOOC(OOCStream<T> qIn, 
Consumer<T> processor, Function<T, Boolean> predicate, Runnable finalizer, 
Consumer<T> onNotProcessed) {
                if (_inQueues == null || _outQueues == null)
                        throw new NotImplementedException("filterOOC requires 
manual specification of all input and output streams for error propagation");
 
-               return submitOOCTasks(qIn, processor, finalizer, predicate);
+               return submitOOCTasks(qIn, processor, finalizer, predicate, 
onNotProcessed != null ? (i, tmp) -> onNotProcessed.accept(tmp) : null);
        }
 
        protected <T, R> CompletableFuture<Void> mapOOC(OOCStream<T> qIn, 
OOCStream<R> qOut, Function<T, R> mapper) {
@@ -163,10 +167,16 @@ public abstract class OOCInstruction extends Instruction {
                leftCache.activateIndexing();
                rightCache.activateIndexing();
 
+               if (!explicitLeftCaching)
+                       leftCache.incrSubscriberCount(1); // Prevent early 
block deletion as we may read elements twice
+
+               if (!explicitRightCaching)
+                       rightCache.incrSubscriberCount(1);
+
                Map<P, List<MatrixIndexes>> availableLeftInput = new 
ConcurrentHashMap<>();
                Map<P, BroadcastedElement> availableBroadcastInput = new 
ConcurrentHashMap<>();
 
-               return submitOOCTasks(List.of(qIn, broadcast), (i, tmp) -> {
+               CompletableFuture<Void> future = submitOOCTasks(List.of(qIn, 
broadcast), (i, tmp) -> {
                        P key = on.apply(tmp);
 
                        if (i == 0) { // qIn stream
@@ -184,11 +194,22 @@ public abstract class OOCInstruction extends Instruction {
                                                return v;
                                        });
                                } else {
+                                       if (!explicitLeftCaching)
+                                               
leftCache.incrProcessingCount(leftCache.findCachedIndex(tmp.getIndexes()), 1); 
// Correct for incremented subscriber count to allow block deletion
+
+                                       b.value = rightCache.peekCached(b.idx);
+
                                        // Directly emit
                                        qOut.enqueue(mapper.apply(tmp, b));
 
-                                       if (b.canRelease())
+                                       b.value = null;
+
+                                       if (b.canRelease()) {
                                                
availableBroadcastInput.remove(key);
+
+                                               if (!explicitRightCaching)
+                                                       
rightCache.incrProcessingCount(rightCache.findCachedIndex(b.idx), 1); // 
Correct for incremented subscriber count to allow block deletion
+                                       }
                                }
                        } else { // broadcast stream
                                if (explicitRightCaching)
@@ -201,16 +222,33 @@ public abstract class OOCInstruction extends Instruction {
 
                                if (queued != null) {
                                        for(MatrixIndexes idx : queued) {
-                                               b.value = 
rightCache.findCached(b.idx);
+                                               b.value = 
rightCache.peekCached(b.idx); // Only peek to prevent block deletion
                                                
qOut.enqueue(mapper.apply(leftCache.findCached(idx), b));
                                                b.value = null;
                                        }
                                }
 
-                               if (b.canRelease())
+                               if (b.canRelease()) {
                                        availableBroadcastInput.remove(key);
+
+                                       if (!explicitRightCaching)
+                                               
rightCache.incrProcessingCount(rightCache.findCachedIndex(tmp.getIndexes()), 
1); // Correct for incremented subscriber count to allow block deletion
+                               }
                        }
-               }, qOut::closeInput);
+               }, () -> {
+                       availableBroadcastInput.forEach((k, v) -> {
+                               
rightCache.incrProcessingCount(rightCache.findCachedIndex(v.idx), 1);
+                       });
+                       availableBroadcastInput.clear();
+                       qOut.closeInput();
+               });
+
+               if (explicitLeftCaching)
+                       leftCache.scheduleDeletion();
+               if (explicitRightCaching)
+                       rightCache.scheduleDeletion();
+
+               return future;
        }
 
        protected static class BroadcastedElement {
@@ -244,7 +282,7 @@ public abstract class OOCInstruction extends Instruction {
                public IndexedMatrixValue getValue() {
                        return value;
                }
-       };
+       }
 
        protected <T, R, P> CompletableFuture<Void> joinOOC(OOCStream<T> qIn1, 
OOCStream<T> qIn2, OOCStream<R> qOut, BiFunction<T, T, R> mapper, Function<T, 
P> on) {
                return joinOOC(qIn1, qIn2, qOut, mapper, on, on);
@@ -257,12 +295,18 @@ public abstract class OOCInstruction extends Instruction {
 
                final CompletableFuture<Void> future = new 
CompletableFuture<>();
 
+               boolean explicitLeftCaching = !qIn1.hasStreamCache();
+               boolean explicitRightCaching = !qIn2.hasStreamCache();
+
                // We need to construct our own stream to properly manage the 
cached items in the hash join
-               CachingStream leftCache = qIn1.hasStreamCache() ? 
qIn1.getStreamCache() : new 
CachingStream((SubscribableTaskQueue<IndexedMatrixValue>)qIn1); // We have to 
assume this generic type for now
-               CachingStream rightCache = qIn2.hasStreamCache() ? 
qIn2.getStreamCache() : new 
CachingStream((SubscribableTaskQueue<IndexedMatrixValue>)qIn2); // We have to 
assume this generic type for now
+               CachingStream leftCache = explicitLeftCaching ? new 
CachingStream((OOCStream<IndexedMatrixValue>) qIn1) : qIn1.getStreamCache();
+               CachingStream rightCache = explicitRightCaching ? new 
CachingStream((OOCStream<IndexedMatrixValue>) qIn2) : qIn2.getStreamCache();
                leftCache.activateIndexing();
                rightCache.activateIndexing();
 
+               leftCache.incrSubscriberCount(1);
+               rightCache.incrSubscriberCount(1);
+
                final OOCJoin<P, MatrixIndexes> join = new OOCJoin<>((idx, 
left, right) -> {
                        T leftObj = (T) leftCache.findCached(left);
                        T rightObj = (T) rightCache.findCached(right);
@@ -280,36 +324,40 @@ public abstract class OOCInstruction extends Instruction {
                        future.complete(null);
                });
 
+               if (explicitLeftCaching)
+                       leftCache.scheduleDeletion();
+               if (explicitRightCaching)
+                       rightCache.scheduleDeletion();
+
                return future;
        }
 
        protected <T> CompletableFuture<Void> submitOOCTasks(final 
List<OOCStream<T>> queues, BiConsumer<Integer, T> consumer, Runnable finalizer) 
{
+               return submitOOCTasks(queues, consumer, finalizer, null);
+       }
+
+       protected <T> CompletableFuture<Void> submitOOCTasks(final 
List<OOCStream<T>> queues, BiConsumer<Integer, T> consumer, Runnable finalizer, 
BiConsumer<Integer, T> onNotProcessed) {
                List<CompletableFuture<Void>> futures = new 
ArrayList<>(queues.size());
 
                for (int i = 0; i < queues.size(); i++)
                        futures.add(new CompletableFuture<>());
 
-               return submitOOCTasks(queues, consumer, finalizer, futures, 
null);
+               return submitOOCTasks(queues, consumer, finalizer, futures, 
null, onNotProcessed);
        }
 
-       protected <T> CompletableFuture<Void> submitOOCTasks(final 
List<OOCStream<T>> queues, BiConsumer<Integer, T> consumer, Runnable finalizer, 
List<CompletableFuture<Void>> futures, BiFunction<Integer, T, Boolean> 
predicate) {
+       protected <T> CompletableFuture<Void> submitOOCTasks(final 
List<OOCStream<T>> queues, BiConsumer<Integer, T> consumer, Runnable finalizer, 
List<CompletableFuture<Void>> futures, BiFunction<Integer, T, Boolean> 
predicate, BiConsumer<Integer, T> onNotProcessed) {
                addInStream(queues.toArray(OOCStream[]::new));
                ExecutorService pool = CommonThreadPool.get();
 
                final List<AtomicInteger> activeTaskCtrs = new 
ArrayList<>(queues.size());
-               final List<AtomicBoolean> streamsClosed = new 
ArrayList<>(queues.size());
 
-               for (int i = 0; i < queues.size(); i++) {
-                       activeTaskCtrs.add(new AtomicInteger(0));
-                       streamsClosed.add(new AtomicBoolean(false));
-               }
+               for (int i = 0; i < queues.size(); i++)
+                       activeTaskCtrs.add(new AtomicInteger(1));
 
-               final AtomicInteger globalTaskCtr = new AtomicInteger(0);
                final CompletableFuture<Void> globalFuture = 
CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new));
                if (_outQueues == null)
                        _outQueues = Collections.emptySet();
                final Runnable oocFinalizer = oocTask(finalizer, null, 
Stream.concat(_outQueues.stream(), 
_inQueues.stream()).toArray(OOCStream[]::new));
-               final Object globalLock = new Object();
 
                int i = 0;
                @SuppressWarnings("unused")
@@ -319,84 +367,67 @@ public abstract class OOCInstruction extends Instruction {
                for (OOCStream<T> queue : queues) {
                        final int k = i;
                        final AtomicInteger localTaskCtr =  
activeTaskCtrs.get(k);
-                       final AtomicBoolean localStreamClosed = 
streamsClosed.get(k);
                        final CompletableFuture<Void> localFuture = 
futures.get(k);
+                       final AtomicBoolean closeRaceWatchdog = new 
AtomicBoolean(false);
 
                        //System.out.println("Substream (k " + k + ", id " + 
streamId + ", type '" + queue.getClass().getSimpleName() + "', stream_id " + 
queue.hashCode() + ")");
-                       queue.setSubscriber(oocTask(() -> {
-                               final T item = queue.dequeue();
+                       queue.setSubscriber(oocTask(callback -> {
+                               final T item = callback.get();
 
-                               if (predicate != null && item != null && 
!predicate.apply(k, item)) // Can get closed due to cancellation
-                                       return;
+                               if(item == null) {
+                                       
if(!closeRaceWatchdog.compareAndSet(false, true))
+                                               throw new 
DMLRuntimeException("Race condition observed: NO_MORE_TASKS callback has been 
triggered more than once");
 
-                               synchronized (globalLock) {
-                                       if (localFuture.isDone())
-                                               return;
+                                       if(localTaskCtr.decrementAndGet() == 0) 
{
+                                               // Then we can run the 
finalization procedure already
+                                               localFuture.complete(null);
+                                       }
+                                       return;
+                               }
 
-                                       globalTaskCtr.incrementAndGet();
+                               if(predicate != null && !predicate.apply(k, 
item)) { // Can get closed due to cancellation
+                                       if(onNotProcessed != null)
+                                               onNotProcessed.accept(k, item);
+                                       return;
                                }
 
-                               localTaskCtr.incrementAndGet();
+                               if(localFuture.isDone()) {
+                                       if(onNotProcessed != null)
+                                               onNotProcessed.accept(k, item);
+                                       return;
+                               }
+                               else {
+                                       localTaskCtr.incrementAndGet();
+                               }
 
                                pool.submit(oocTask(() -> {
-                                       if(item != null) {
-                                               //System.out.println("Accept" + 
((IndexedMatrixValue)item).getIndexes() + " (k " + k + ", id " + streamId + 
")");
-                                               consumer.accept(k, item);
-                                       }
-                                       else {
-                                               //System.out.println("Close 
substream (k " + k + ", id " + streamId + ")");
-                                               localStreamClosed.set(true);
-                                       }
-
-                                       boolean runFinalizer = false;
-
-                                       synchronized (globalLock) {
-                                               int localTasks = 
localTaskCtr.decrementAndGet();
-                                               boolean finalizeStream = 
localTasks == 0 && localStreamClosed.get();
-
-                                               int globalTasks = 
globalTaskCtr.get() - 1;
-
-                                               if (finalizeStream || 
(globalFuture.isDone() && localTasks == 0)) {
-                                                       
localFuture.complete(null);
+                                       // TODO For caching streams, we have no 
guarantee that item is still in memory -> NullPointer possible
+                                       consumer.accept(k, item);
 
-                                                       if 
(globalFuture.isDone() && globalTasks == 0)
-                                                               runFinalizer = 
true;
-                                               }
-
-                                               globalTaskCtr.decrementAndGet();
-                                       }
-
-                                       if (runFinalizer)
-                                               oocFinalizer.run();
+                                       if(localTaskCtr.decrementAndGet() == 0)
+                                               localFuture.complete(null);
                                }, localFuture, 
Stream.concat(_outQueues.stream(), 
_inQueues.stream()).toArray(OOCStream[]::new)));
+
+                               if(closeRaceWatchdog.get()) // Sanity check
+                                       throw new DMLRuntimeException("Race 
condition observed");
                        }, null,  Stream.concat(_outQueues.stream(), 
_inQueues.stream()).toArray(OOCStream[]::new)));
 
                        i++;
                }
 
-               pool.shutdown();
-
                globalFuture.whenComplete((res, e) -> {
-                       if (globalFuture.isCancelled() || 
globalFuture.isCompletedExceptionally())
+                       if (globalFuture.isCancelled() || 
globalFuture.isCompletedExceptionally()) {
                                futures.forEach(f -> {
-                                       if (!f.isDone()) {
-                                               if (globalFuture.isCancelled() 
|| globalFuture.isCompletedExceptionally())
+                                       if(!f.isDone()) {
+                                               if(globalFuture.isCancelled() 
|| globalFuture.isCompletedExceptionally())
                                                        f.cancel(true);
                                                else
                                                        f.complete(null);
                                        }
                                });
-
-                       boolean runFinalizer;
-
-                       synchronized (globalLock) {
-                               runFinalizer = globalTaskCtr.get() == 0;
                        }
 
-                       if (runFinalizer)
-                               oocFinalizer.run();
-
-                       //System.out.println("Shutdown (id " + streamId + ")");
+                       oocFinalizer.run();
                });
                return globalFuture;
        }
@@ -405,8 +436,8 @@ public abstract class OOCInstruction extends Instruction {
                return submitOOCTasks(List.of(queue), (i, tmp) -> 
consumer.accept(tmp), finalizer);
        }
 
-       protected <T> CompletableFuture<Void> submitOOCTasks(OOCStream<T> 
queue, Consumer<T> consumer, Runnable finalizer, Function<T, Boolean> 
predicate) {
-               return submitOOCTasks(List.of(queue), (i, tmp) -> 
consumer.accept(tmp), finalizer, List.of(new CompletableFuture<Void>()), (i, 
tmp) -> predicate.apply(tmp));
+       protected <T> CompletableFuture<Void> submitOOCTasks(OOCStream<T> 
queue, Consumer<T> consumer, Runnable finalizer, Function<T, Boolean> 
predicate, BiConsumer<Integer, T> onNotProcessed) {
+               return submitOOCTasks(List.of(queue), (i, tmp) -> 
consumer.accept(tmp), finalizer, List.of(new CompletableFuture<Void>()), (i, 
tmp) -> predicate.apply(tmp), onNotProcessed);
        }
 
        protected CompletableFuture<Void> submitOOCTask(Runnable r, 
OOCStream<?>... queues) {
@@ -450,6 +481,40 @@ public abstract class OOCInstruction extends Instruction {
                };
        }
 
+       /**
+        * Wraps a queue-callback consumer so that exceptions thrown by it are converted into a
+        * {@link DMLRuntimeException} and rethrown. On the first failure of this instruction, the
+        * exception is additionally propagated to the given streams and, if non-null, used to
+        * complete the given future exceptionally.
+        */
+       private <T> Consumer<OOCStream.QueueCallback<T>> oocTask(Consumer<OOCStream.QueueCallback<T>> c, CompletableFuture<Void> future, OOCStream<?>... queues) {
+               return callback -> {
+                       try {
+                               c.accept(callback);
+                       }
+                       catch (Exception ex) {
+                               DMLRuntimeException re = ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
+
+                               // Already failed (e.g., re-entered via propagateFailure -> subscriber callback):
+                               // rethrow without propagating again to avoid infinite cycles
+                               if (_failed)
+                                       throw re;
+
+                               // NOTE(review): this check-then-set on _failed is not atomic, so concurrent failures may each propagate
+                               _failed = true;
+
+                               for (OOCStream<?> q : queues)
+                                       q.propagateFailure(re);
+
+                               if (future != null)
+                                       future.completeExceptionally(re);
+
+                               // Rethrow to ensure proper future handling
+                               throw re;
+                       }
+               };
+       }
+
        /**
         * Tracks blocks and their counts to enable early emission
         * once all blocks for a given index are processed.
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
index 1a12cb138b..f02c847e05 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
@@ -22,6 +22,8 @@ package org.apache.sysds.runtime.instructions.ooc;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
 
+import java.util.function.Consumer;
+
 public interface OOCStream<T> extends OOCStreamable<T> {
        void enqueue(T t);
 
@@ -36,4 +38,28 @@ public interface OOCStream<T> extends OOCStreamable<T> {
        boolean hasStreamCache();
 
        CachingStream getStreamCache();
+
+       /**
+        * Registers a new subscriber that consumes the stream.
+        * While there is no guarantee for any specific order, the closing item 
LocalTaskQueue.NO_MORE_TASKS
+        * is guaranteed to be invoked after every other item has finished 
processing. Thus, the NO_MORE_TASKS
+        * callback can be used to free dependent resources and close output 
streams.
+        */
+       void setSubscriber(Consumer<QueueCallback<T>> subscriber);
+
+       class QueueCallback<T> {
+               private final T _result;
+               private final DMLRuntimeException _failure;
+
+               public QueueCallback(T result, DMLRuntimeException failure) {
+                       _result = result;
+                       _failure = failure;
+               }
+
+               public T get() {
+                       if (_failure != null)
+                               throw _failure;
+                       return _result;
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
index bdc4086bdc..af2c0afa66 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
@@ -25,6 +25,4 @@ public interface OOCStreamable<T> {
        OOCStream<T> getWriteStream();
 
        boolean isProcessed();
-
-       void setSubscriber(Runnable subscriber);
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java
new file mode 100644
index 0000000000..f56b0c46da
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCWatchdog.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Watchdog to help debug OOC streams/tasks that never close.
+ * <p>
+ * Streams register via {@link #registerOpen} and deregister via {@link #registerClose}; a daemon
+ * thread periodically reports entries that are still open after {@code STALE_MS} and have recorded
+ * at least one event. The periodic scan is only scheduled when {@link #WATCH} is enabled.
+ */
+public final class OOCWatchdog {
+       public static final boolean WATCH = false;
+       private static final ConcurrentHashMap<String, Entry> OPEN = new ConcurrentHashMap<>();
+       private static final ScheduledExecutorService EXEC =
+               Executors.newSingleThreadScheduledExecutor(r -> {
+                       Thread t = new Thread(r, "TemporaryWatchdog");
+                       t.setDaemon(true);
+                       return t;
+               });
+
+       private static final long STALE_MS = TimeUnit.SECONDS.toMillis(10);
+       private static final long SCAN_INTERVAL_MS = TimeUnit.SECONDS.toMillis(10);
+
+       static {
+               // Only spawn the scanner thread if watching is actually enabled
+               if (WATCH)
+                       EXEC.scheduleAtFixedRate(OOCWatchdog::scan, SCAN_INTERVAL_MS, SCAN_INTERVAL_MS, TimeUnit.MILLISECONDS);
+       }
+
+       private OOCWatchdog() {
+               // no-op
+       }
+
+       /** Starts tracking a stream under the given id, replacing any previous entry with the same id. */
+       public static void registerOpen(String id, String desc, String context, OOCStream<?> stream) {
+               OPEN.put(id, new Entry(desc, context, System.currentTimeMillis(), stream));
+       }
+
+       /** Records a diagnostic event for a tracked stream; ignored if the id is not (or no longer) tracked. */
+       public static void addEvent(String id, String eventMsg) {
+               Entry e = OPEN.get(id);
+               if (e != null)
+                       e.events.add(eventMsg);
+       }
+
+       /** Stops tracking the stream with the given id. */
+       public static void registerClose(String id) {
+               OPEN.remove(id);
+       }
+
+       private static void scan() {
+               long now = System.currentTimeMillis();
+               for (Map.Entry<String, Entry> e : OPEN.entrySet()) {
+                       Entry entry = e.getValue(); // Read once; the mapping may be replaced concurrently
+                       long openMs = now - entry.openedAt;
+                       if (openMs < STALE_MS)
+                               continue;
+                       if (entry.events.isEmpty())
+                               continue; // Probably just a stream that has no consumer (remains to be checked why this can happen)
+                       System.err.println("[TemporaryWatchdog] Still open after " + openMs + "ms: "
+                               + e.getKey() + " (" + entry.desc + ")"
+                               + (entry.context != null ? " ctx=" + entry.context : "")
+                               + " events=" + entry.events.size());
+               }
+       }
+
+       private static final class Entry {
+               final String desc;
+               final String context;
+               final long openedAt;
+               @SuppressWarnings("unused")
+               final OOCStream<?> stream;
+               final ConcurrentLinkedQueue<String> events;
+
+               Entry(String desc, String context, long openedAt, OOCStream<?> stream) {
+                       this.desc = desc;
+                       this.context = context;
+                       this.openedAt = openedAt;
+                       this.stream = stream;
+                       this.events = new ConcurrentLinkedQueue<>();
+               }
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java
index e56d32e440..d70fc3ccb9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ParameterizedBuiltinOOCInstruction.java
@@ -20,7 +20,6 @@
 package org.apache.sysds.runtime.instructions.ooc;
 
 import org.apache.commons.lang3.NotImplementedException;
-import org.apache.commons.lang3.mutable.MutableObject;
 import org.apache.sysds.common.Opcodes;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -43,7 +42,6 @@ import 
org.apache.sysds.runtime.matrix.operators.SimpleOperator;
 import java.util.LinkedHashMap;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.atomic.AtomicBoolean;
 
 public class ParameterizedBuiltinOOCInstruction extends 
ComputationOOCInstruction {
 
@@ -110,29 +108,26 @@ public class ParameterizedBuiltinOOCInstruction extends 
ComputationOOCInstructio
 
                        Data finalPattern = pattern;
 
-                       AtomicBoolean found = new AtomicBoolean(false);
+                       addInStream(qIn);
+                       addOutStream(); // This instruction has no output stream
 
-                       MutableObject<CompletableFuture<Void>> futureRef = new 
MutableObject<>();
-                       CompletableFuture<Void> future = submitOOCTasks(qIn, 
tmp -> {
-                               boolean contains = 
((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue());
+                       CompletableFuture<Boolean> future = new 
CompletableFuture<>();
 
-                               if (contains) {
-                                       found.set(true);
+                       filterOOC(qIn, tmp -> {
+                               boolean contains = 
((MatrixBlock)tmp.getValue()).containsValue(((ScalarObject)finalPattern).getDoubleValue());
 
-                                       // Now we may complete the future
-                                       if (futureRef.getValue() != null)
-                                               
futureRef.getValue().complete(null);
-                               }
-                       }, () -> {});
-                       futureRef.setValue(future);
+                               if (contains)
+                                       future.complete(true);
+                       }, tmp -> !future.isDone(), // Don't start a separate 
worker if result already known
+                               () -> future.complete(false));     // Then the 
pattern was not found
 
+                       boolean ret;
                        try {
-                               futureRef.getValue().get();
+                               ret = future.get();
                        } catch (InterruptedException | ExecutionException e) {
                                throw new DMLRuntimeException(e);
                        }
 
-                       boolean ret = found.get();
                        ec.setScalarOutput(output.getName(), new 
BooleanObject(ret));
                }
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java
index 6edc4ecf27..5b996da0db 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java
@@ -23,13 +23,22 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
 import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
+
 public class PlaybackStream implements OOCStream<IndexedMatrixValue>, 
OOCStreamable<IndexedMatrixValue> {
        private final CachingStream _streamCache;
-       private int _streamIdx;
+       private final AtomicInteger _streamIdx;
+       private final AtomicInteger _taskCtr;
+       private final AtomicBoolean _subscriberSet;
 
        public PlaybackStream(CachingStream streamCache) {
                this._streamCache = streamCache;
-               this._streamIdx = 0;
+               this._streamIdx = new AtomicInteger(0);
+               this._taskCtr = new AtomicInteger(1);
+               this._subscriberSet = new AtomicBoolean(false);
+               streamCache.incrSubscriberCount(1);
        }
 
        @Override
@@ -44,15 +53,29 @@ public class PlaybackStream implements OOCStream<IndexedMatrixValue>, OOCStreama
 
         @Override
         public LocalTaskQueue<IndexedMatrixValue> toLocalTaskQueue() {
-               final SubscribableTaskQueue<IndexedMatrixValue> q = new SubscribableTaskQueue<>();
-               setSubscriber(() -> q.enqueue(dequeue()));
+               final LocalTaskQueue<IndexedMatrixValue> q = new LocalTaskQueue<>();
+               setSubscriber(callback -> {
+                       final IndexedMatrixValue val = callback.get(); // null marks the end of the stream
+                       try {
+                               if (val == null)
+                                       q.closeInput();
+                               else
+                                       q.enqueueTask(val);
+                       } catch (InterruptedException e) {
+                               Thread.currentThread().interrupt(); // Preserve the interrupt status
+                               throw new DMLRuntimeException(e);
+                       }
+               });
                return q;
         }
 
         @Override
-       public synchronized IndexedMatrixValue dequeue() {
+       public IndexedMatrixValue dequeue() {
+               if (_subscriberSet.get())
+                       throw new IllegalStateException("Cannot dequeue from a playback stream if a subscriber has been set");
+
                try {
-                       return _streamCache.get(_streamIdx++);
+                       return _streamCache.get(_streamIdx.getAndIncrement());
                } catch (InterruptedException e) {
                        throw new DMLRuntimeException(e);
                }
@@ -74,8 +97,35 @@ public class PlaybackStream implements 
OOCStream<IndexedMatrixValue>, OOCStreama
        }
 
        @Override
-       public void setSubscriber(Runnable subscriber) {
-               _streamCache.setSubscriber(subscriber);
+       public void setSubscriber(Consumer<QueueCallback<IndexedMatrixValue>> 
subscriber) {
+               if (!_subscriberSet.compareAndSet(false, true))
+                       throw new IllegalArgumentException("Subscriber cannot 
be set multiple times");
+
+               /**
+                * To guarantee that NO_MORE_TASKS is invoked after all 
subscriber calls
+                * finished, we keep track of running tasks using a task 
counter.
+                */
+               _streamCache.setSubscriber(() -> {
+                       try {
+                               _taskCtr.incrementAndGet();
+
+                               IndexedMatrixValue val;
+
+                               try {
+                                       val = 
_streamCache.get(_streamIdx.getAndIncrement());
+                               } catch (InterruptedException e) {
+                                       throw new DMLRuntimeException(e);
+                               }
+
+                               if (val != null)
+                                       subscriber.accept(new 
QueueCallback<>(val, null));
+
+                               if (_taskCtr.addAndGet(val == null ? -2 : -1) 
== 0)
+                                       subscriber.accept(new 
QueueCallback<>(null, null));
+                       } catch (DMLRuntimeException e) {
+                               subscriber.accept(new QueueCallback<>(null, e));
+                       }
+               }, false);
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
index f136ffc2bb..d88cd1c5af 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
@@ -22,80 +22,195 @@ package org.apache.sysds.runtime.instructions.ooc;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
 
+import java.util.LinkedList;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Consumer;
+
 public class SubscribableTaskQueue<T> extends LocalTaskQueue<T> implements OOCStream<T> {
-       private Runnable _subscriber;
 
-       @Override
-       public synchronized void enqueue(T t) {
-               try {
-                       super.enqueueTask(t);
-               }
-               catch (InterruptedException e) {
-                       throw new DMLRuntimeException(e);
+       // 1 while the input is open, plus one per enqueued item that has not been delivered yet
+       private final AtomicInteger _availableCtr = new AtomicInteger(1);
+       private final AtomicBoolean _closed = new AtomicBoolean(false);
+       // Set once the input is closed and every item has been delivered
+       private final AtomicBoolean _drained = new AtomicBoolean(false);
+       // Ensures NO_MORE_TASKS is handed to the subscriber at most once
+       private final AtomicBoolean _closeDelivered = new AtomicBoolean(false);
+       private volatile Consumer<QueueCallback<T>> _subscriber = null;
+       private String _watchdogId;
+
+       public SubscribableTaskQueue() {
+               if (OOCWatchdog.WATCH) {
+                       _watchdogId = "STQ-" + hashCode();
+                       // Capture a short context to help identify origin
+                       OOCWatchdog.registerOpen(_watchdogId, "SubscribableTaskQueue@" + hashCode(), getCtxMsg(), this);
                }
+       }
 
-               if(_subscriber != null)
-                       _subscriber.run();
+       private String getCtxMsg() {
+               StackTraceElement[] st = new Exception().getStackTrace();
+               // Skip the first few frames (constructor, createWritableStream, etc.)
+               StringBuilder sb = new StringBuilder();
+               int limit = Math.min(st.length, 7);
+               for(int i = 2; i < limit; i++) {
+                       sb.append(st[i].getClassName()).append(".").append(st[i].getMethodName()).append(":")
+                               .append(st[i].getLineNumber());
+                       if(i < limit - 1)
+                               sb.append(" <- ");
+               }
+               return sb.toString();
        }
 
        @Override
-       public T dequeue() {
-               try {
-                       return super.dequeueTask();
+       public void enqueue(T t) {
+               if (t == NO_MORE_TASKS)
+                       throw new DMLRuntimeException("Cannot enqueue NO_MORE_TASKS item");
+
+               int cnt = _availableCtr.incrementAndGet();
+
+               if (cnt <= 1) { // Then the queue was already closed and we disallow further enqueues
+                       _availableCtr.decrementAndGet(); // Undo increment
+                       throw new DMLRuntimeException("Cannot enqueue into closed SubscribableTaskQueue");
                }
-               catch (InterruptedException e) {
-                       throw new DMLRuntimeException(e);
+
+               Consumer<QueueCallback<T>> s = _subscriber;
+
+               if (s != null) {
+                       s.accept(new QueueCallback<>(t, _failure));
+                       onDeliveryFinished();
+                       return;
                }
+
+               synchronized (this) {
+                       // Re-check that subscriber is really null to avoid race conditions
+                       if (_subscriber == null) {
+                               try {
+                                       super.enqueueTask(t);
+                               }
+                               catch(InterruptedException e) {
+                                       throw new DMLRuntimeException(e);
+                               }
+                               return;
+                       }
+                       // Otherwise do not insert and re-schedule subscriber invocation
+                       s = _subscriber;
+               }
+
+               // Last case: due to a race, a subscriber has been set in the meantime
+               s.accept(new QueueCallback<>(t, _failure));
+               onDeliveryFinished();
        }
 
        @Override
-       public synchronized void closeInput() {
-               super.closeInput();
-
-               if(_subscriber != null) {
-                       _subscriber.run();
-                       _subscriber = null;
-               }
+       public synchronized void enqueueTask(T t) {
+               enqueue(t);
        }
 
        @Override
-       public LocalTaskQueue<T> toLocalTaskQueue() {
-               return this;
+       public T dequeue() {
+               try {
+                       if (OOCWatchdog.WATCH)
+                               OOCWatchdog.addEvent(_watchdogId, "dequeue -- " + getCtxMsg());
+                       T deq = super.dequeueTask();
+                       if (deq != NO_MORE_TASKS)
+                               onDeliveryFinished();
+                       return deq;
+               }
+               catch(InterruptedException e) {
+                       throw new DMLRuntimeException(e);
+               }
        }
 
        @Override
-       public OOCStream<T> getReadStream() {
-               return this;
+       public synchronized T dequeueTask() {
+               return dequeue();
        }
 
        @Override
-       public OOCStream<T> getWriteStream() {
-               return this;
+       public synchronized void closeInput() {
+               if (_closed.compareAndSet(false, true)) {
+                       super.closeInput();
+                       onDeliveryFinished();
+               } else {
+                       throw new IllegalStateException("Multiple close input calls");
+               }
        }
 
        @Override
-       public void setSubscriber(Runnable subscriber) {
-               int queueSize;
+       public void setSubscriber(Consumer<QueueCallback<T>> subscriber) {
+               if(subscriber == null)
+                       throw new IllegalArgumentException("Cannot set subscriber to null");
 
-               synchronized (this) {
+               LinkedList<T> data;
+
+               synchronized(this) {
                        if(_subscriber != null)
                                throw new DMLRuntimeException("Cannot set multiple subscribers");
-
                        _subscriber = subscriber;
-                       queueSize = _data.size();
-                       queueSize += _closedInput ? 1 : 0; // To trigger the NO_MORE_TASK element
+                       if(_failure != null)
+                               throw _failure;
+                       data = _data;
+                       _data = new LinkedList<>();
+               }
+
+               for (T t : data) {
+                       subscriber.accept(new QueueCallback<>(t, _failure));
+                       onDeliveryFinished();
                }
+
+               // If the input was closed and fully drained before this subscriber was set, the counter
+               // already hit zero with nobody to notify, so NO_MORE_TASKS must be delivered now
+               if (_drained.get())
+                       deliverClose();
+       }
 
-               for (int i = 0; i < queueSize; i++)
-                       subscriber.run();
+       /**
+        * Accounts for one finished delivery (or the input close); reaching zero means the input is
+        * closed and every item has been delivered.
+        */
+       private void onDeliveryFinished() {
+               int ctr = _availableCtr.decrementAndGet();
+
+               if (ctr == 0) {
+                       _drained.set(true);
+                       deliverClose();
+
+                       if (OOCWatchdog.WATCH)
+                               OOCWatchdog.registerClose(_watchdogId);
+               }
        }
+
+       /**
+        * Hands NO_MORE_TASKS to the current subscriber at most once; does nothing while no subscriber is set.
+        */
+       @SuppressWarnings("unchecked")
+       private void deliverClose() {
+               Consumer<QueueCallback<T>> s = _subscriber;
+               if (s != null && _closeDelivered.compareAndSet(false, true))
+                       s.accept(new QueueCallback<>((T) LocalTaskQueue.NO_MORE_TASKS, _failure));
+       }
 
        @Override
        public synchronized void propagateFailure(DMLRuntimeException re) {
                super.propagateFailure(re);
+               Consumer<QueueCallback<T>> s = _subscriber;
+               if(s != null)
+                       s.accept(new QueueCallback<>(null, re));
+       }
+
+       @Override
+       public LocalTaskQueue<T> toLocalTaskQueue() {
+               return this;
+       }
+
+       @Override
+       public OOCStream<T> getReadStream() {
+               return this;
+       }
 
-               if(_subscriber != null)
-                       _subscriber.run();
+       @Override
+       public OOCStream<T> getWriteStream() {
+               return this;
        }
 
        @Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
index fd80b4e6e9..aba36297e7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
@@ -25,8 +25,50 @@ import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 
+import java.util.concurrent.ConcurrentHashMap;
+
 public class TeeOOCInstruction extends ComputationOOCInstruction {
 
+       // Reference counts of CachingStreams shared between stream handles, adjusted via
+       // incrRef (processInstruction adds one count per matrix object it assigns the stream to).
+       private static final ConcurrentHashMap<CachingStream, Integer> refCtr = new ConcurrentHashMap<>();
+
+       /**
+        * Clears all reference counters. Entries that are still present are reported
+        * on stderr as dangling streams and dropped without scheduling their deletion.
+        */
+       public static void reset() {
+               if (!refCtr.isEmpty()) {
+                       System.err.println("There are some dangling streams still in the cache: " + refCtr);
+                       refCtr.clear();
+               }
+       }
+
+       /**
+        * Increments the reference counter of a stream by the set amount.
+        * Only {@link CachingStream} instances are tracked; other streams are ignored.
+        * If the resulting counter is zero or below (which includes an untracked stream
+        * with a non-positive {@code incr}), the entry is removed and
+        * {@link CachingStream#scheduleDeletion()} is invoked on the stream.
+        *
+        * @param stream the stream whose counter is adjusted
+        * @param incr   the amount to add; negative values release references
+        */
+       public static void incrRef(OOCStreamable<IndexedMatrixValue> stream, int incr) {
+               if (!(stream instanceof CachingStream))
+                       return;
+
+               Integer ref = refCtr.compute((CachingStream)stream, (k, v) -> {
+                       if (v == null)
+                               v = 0;
+                       v += incr;
+                       return v <= 0 ? null : v;
+               });
+
+               if (ref == null)
+                       ((CachingStream)stream).scheduleDeletion();
+       }
+
        protected TeeOOCInstruction(OOCType type, CPOperand in1, CPOperand out, String opcode, String istr) {
                super(type, null, in1, out, opcode, istr);
        }
@@ -45,9 +87,24 @@ public class TeeOOCInstruction extends ComputationOOCInstruction {
                MatrixObject min = ec.getMatrixObject(input1);
                OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
 
-               //get output and create new resettable stream
+               // Reuse the input's stream cache if it already has one; otherwise wrap the input
+               // in a new CachingStream. hasStreamCache() is evaluated only once so that the
+               // handle choice and the reference accounting below always agree.
+               boolean hasCache = qIn.hasStreamCache();
+               CachingStream handle = hasCache ? qIn.getStreamCache() : new CachingStream(qIn);
+
+               if (!hasCache) {
+                       // We also set the input stream handle
+                       min.setStreamHandle(handle);
+                       incrRef(handle, 2); // one reference each for the input and the output handle
+               }
+               else {
+                       incrRef(handle, 1); // only the output handle is newly assigned
+               }
+
+               //get output and attach the caching stream handle
                MatrixObject mo = ec.getMatrixObject(output);
-               mo.setStreamHandle(new CachingStream(qIn));
+               mo.setStreamHandle(handle);
                mo.setMetaData(min.getMetaData());
        }
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
new file mode 100644
index 0000000000..e20b7ec426
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/PCATest.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.io.IOException;
+
+/**
+ * Runs the pca() builtin through PCA.dml on a random input matrix written in binary
+ * format, once with the -ooc flag and once without it.
+ */
+public class PCATest extends AutomatedTestBase {
+       private final static String TEST_NAME1 = "PCA";
+       private final static String TEST_DIR = "functions/ooc/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + PCATest.class.getSimpleName() + "/";
+       //private final static double eps = 1e-8;
+       private static final String INPUT_NAME_1 = "X";
+       private static final String OUTPUT_NAME_1 = "PC";
+       private static final String OUTPUT_NAME_2 = "V";
+
+       private final static int rows = 50000;
+       private final static int cols = 1000;
+       private final static int maxVal = 2;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+               addTestConfiguration(TEST_NAME1, config);
+       }
+
+       @Test
+       public void testPCA() {
+               boolean allow_opfusion = OptimizerUtils.ALLOW_OPERATOR_FUSION;
+               OptimizerUtils.ALLOW_OPERATOR_FUSION = false; // some fused ops are not implemented yet
+               try {
+                       runPCATest(16);
+               }
+               finally {
+                       // restore the global flag even if the run fails, so later tests are not affected
+                       OptimizerUtils.ALLOW_OPERATOR_FUSION = allow_opfusion;
+               }
+       }
+
+       /**
+        * Writes a random rows x cols matrix in binary format to input(INPUT_NAME_1) and runs
+        * PCA.dml on it twice: first with the -ooc flag, then without it.
+        *
+        * @param k value passed to the script as $2, which PCA.dml forwards to pca() as K
+        */
+       private void runPCATest(int k) {
+               Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME1);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+                       programArgs = new String[] {"-explain", "hops", "-stats", "-ooc", "-args", input(INPUT_NAME_1), Integer.toString(k), output(OUTPUT_NAME_1), output(OUTPUT_NAME_2)};
+
+                       // 1. Generate the data in-memory as a double array
+                       double[][] X_data = getRandomMatrix(rows, cols, 0, maxVal, 1, 7);
+
+                       // 2. Convert the double array to a MatrixBlock
+                       MatrixBlock X_mb = DataConverter.convertToMatrixBlock(X_data);
+
+                       // 3. Create a binary matrix writer
+                       MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+
+                       // 4. Write matrix X to a binary SequenceFile
+                       writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME_1), rows, cols, 1000, X_mb.getNonZeros());
+                       HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64,
+                               new MatrixCharacteristics(rows, cols, 1000, X_mb.getNonZeros()), Types.FileFormat.BINARY);
+                       X_data = null;
+                       X_mb = null;
+
+                       runTest(true, false, null, -1);
+
+                       // TODO: assert that OOC instructions were used; the disabled check below still targets Opcodes.REPLACE
+                       //Assert.assertTrue("OOC wasn't used for replacement",
+                       //      heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.REPLACE));
+
+                       //compare results
+
+                       // rerun without ooc flag
+                       programArgs = new String[] {"-explain", "hops", "-stats", "-args", input(INPUT_NAME_1), Integer.toString(k), output(OUTPUT_NAME_1 + "_target"), output(OUTPUT_NAME_2 + "_target")};
+                       runTest(true, false, null, -1);
+
+                       // compare matrices (disabled: the test currently only executes both runs)
+                       // NOTE(review): the dimensions passed below (rows x cols) may not match the
+                       // shapes of the pca() outputs PC and V for K=k -- confirm before re-enabling.
+                       /*MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_1),
+                               Types.FileFormat.BINARY, rows, cols, 1000);
+                       MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_1 + "_target"),
+                               Types.FileFormat.BINARY, rows, cols, 1000);
+                       TestUtils.compareMatrices(ret1, ret2, eps);
+
+                       MatrixBlock ret2_1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_2),
+                               Types.FileFormat.BINARY, rows, cols, 1000);
+                       MatrixBlock ret2_2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME_2 + "_target"),
+                               Types.FileFormat.BINARY, rows, cols, 1000);
+                       TestUtils.compareMatrices(ret2_1, ret2_2, eps);*/
+               }
+               catch(IOException e) {
+                       throw new RuntimeException(e);
+               }
+               finally {
+                       resetExecMode(platformOld);
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java 
b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
index 4365dd629f..6efa72d6ae 100644
--- a/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
+++ b/src/test/java/org/apache/sysds/test/usertest/pythonapi/StartupTest.java
@@ -304,6 +304,8 @@ public class StartupTest {
                script.startWritingColToPipe(0, null, 0);
        }
        private static class ExitCalled extends RuntimeException implements 
PythonDMLScript.ExitHandler {
+               private static final long serialVersionUID = 
-4247240099965056602L;
+
                @Override
                public void exit(int status) {
                        throw this;
diff --git a/src/test/scripts/functions/ooc/PCA.dml b/src/test/scripts/functions/ooc/PCA.dml
new file mode 100644
index 0000000000..567d701ec0
--- /dev/null
+++ b/src/test/scripts/functions/ooc/PCA.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Arguments: $1 = path of input matrix X, $2 = value passed to pca() as K,
+#            $3 / $4 = output paths for PC and V (written in binary format)
+X = read($1);
+k = $2;
+
+[PC, V] = pca(X=X, K=k);
+
+write(PC, $3, format="binary");
+write(V, $4, format="binary");

Reply via email to