This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 6514887778 [SYSTEMDS-3918] New out-of-core queues and primitives
6514887778 is described below
commit 6514887778059b92f39eeaf29a22ac02eb24bdd8
Author: Jannik Lindemann <[email protected]>
AuthorDate: Sun Nov 9 09:22:47 2025 +0100
[SYSTEMDS-3918] New out-of-core queues and primitives
Closes #2347.
---
src/main/java/org/apache/sysds/hops/BinaryOp.java | 2 +-
.../controlprogram/caching/CacheableData.java | 33 ++-
.../controlprogram/caching/MatrixObject.java | 2 +-
.../controlprogram/parfor/LocalTaskQueue.java | 24 +--
.../ooc/AggregateUnaryOOCInstruction.java | 10 +-
.../instructions/ooc/BinaryOOCInstruction.java | 63 +++---
.../runtime/instructions/ooc/CachingStream.java | 166 +++++++++++++++
.../ooc/CentralMomentOOCInstruction.java | 86 ++------
.../instructions/ooc/CtableOOCInstruction.java | 10 +-
.../ooc/MatrixVectorBinaryOOCInstruction.java | 10 +-
.../runtime/instructions/ooc/OOCInstruction.java | 236 ++++++++++++++++++++-
.../sysds/runtime/instructions/ooc/OOCStream.java | 39 ++++
.../runtime/instructions/ooc/OOCStreamable.java | 30 +++
.../runtime/instructions/ooc/PlaybackStream.java | 95 +++++++++
.../instructions/ooc/ReblockOOCInstruction.java | 7 +-
.../runtime/instructions/ooc/ResettableStream.java | 116 ----------
.../instructions/ooc/SubscribableTaskQueue.java | 110 ++++++++++
.../instructions/ooc/TSMMOOCInstruction.java | 4 +-
.../instructions/ooc/TeeOOCInstruction.java | 5 +-
.../instructions/ooc/TransposeOOCInstruction.java | 29 +--
.../instructions/ooc/UnaryOOCInstruction.java | 28 +--
.../sysds/runtime/matrix/operators/CMOperator.java | 7 +
.../org/apache/sysds/runtime/util/OOCJoin.java | 69 ++++++
.../test/functions/ooc/BinaryMatrixMatrixTest.java | 136 ++++++++++++
.../test/functions/ooc/BinaryMatrixScalarTest.java | 122 +++++++++++
.../scripts/functions/ooc/BinaryMatrixMatrix.dml | 29 +++
.../scripts/functions/ooc/BinaryMatrixScalar.dml | 28 +++
27 files changed, 1178 insertions(+), 318 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index a3ddb45ea6..2b803a053c 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -478,7 +478,7 @@ public class BinaryOp extends MultiThreadedHop {
setLineNumbers(softmax);
setLops(softmax);
}
- else if ( et == ExecType.CP || et == ExecType.GPU || et
== ExecType.FED )
+ else if ( et == ExecType.CP || et == ExecType.GPU || et
== ExecType.FED || et == ExecType.OOC )
{
Lop binary = null;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 7457e56ba5..34a8aa1863 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -49,7 +49,9 @@ import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
-import org.apache.sysds.runtime.instructions.ooc.ResettableStream;
+import org.apache.sysds.runtime.instructions.ooc.OOCStream;
+import org.apache.sysds.runtime.instructions.ooc.OOCStreamable;
+import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue;
import org.apache.sysds.runtime.instructions.spark.data.BroadcastObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
@@ -223,7 +225,7 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
private BroadcastObject<T> _bcHandle = null; //Broadcast handle
protected HashMap<GPUContext, GPUObject> _gpuObjects = null; //Per
GPUContext object allocated on GPU
//TODO generalize for frames
- private LocalTaskQueue<IndexedMatrixValue> _streamHandle = null;
+ private OOCStreamable<IndexedMatrixValue> _streamHandle = null;
private LineageItem _lineage = null;
@@ -469,34 +471,25 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
return _bcHandle != null && _bcHandle.hasBackReference();
}
- public LocalTaskQueue<IndexedMatrixValue> getStreamHandle() {
+ public OOCStream<IndexedMatrixValue> getStreamHandle() {
if( !hasStreamHandle() ) {
- _streamHandle = new LocalTaskQueue<>();
+ final SubscribableTaskQueue<IndexedMatrixValue>
_mStream = new SubscribableTaskQueue<>();
+ _streamHandle = _mStream;
DataCharacteristics dc = getDataCharacteristics();
MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
LongStream.range(0, dc.getNumBlocks())
.mapToObj(i ->
UtilFunctions.createIndexedMatrixBlock(src, dc, i))
.forEach( blk -> {
try{
- _streamHandle.enqueueTask(blk);
+ _mStream.enqueue(blk);
}
catch(Exception ex) {
- throw new
DMLRuntimeException(ex);
+ throw ex instanceof
DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
}});
- _streamHandle.closeInput();
- }
- else if(_streamHandle != null && _streamHandle.isProcessed()
- && _streamHandle instanceof ResettableStream)
- {
- try {
- ((ResettableStream)_streamHandle).reset();
- }
- catch(Exception ex) {
- throw new DMLRuntimeException(ex);
- }
+ _mStream.closeInput();
}
- return _streamHandle;
+ return _streamHandle.getReadStream();
}
/**
@@ -539,7 +532,7 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
_gpuObjects.remove(gCtx);
}
- public synchronized void
setStreamHandle(LocalTaskQueue<IndexedMatrixValue> q) {
+ public synchronized void
setStreamHandle(OOCStreamable<IndexedMatrixValue> q) {
_streamHandle = q;
}
@@ -633,7 +626,7 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
_requiresLocalWrite = false;
}
else if( hasStreamHandle() ) {
- _data = readBlobFromStream(
getStreamHandle() );
+ _data = readBlobFromStream(
getStreamHandle().toLocalTaskQueue() );
}
else if( getRDDHandle()==null ||
getRDDHandle().allowsShortCircuitRead() ) {
if( DMLScript.STATISTICS )
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 9f4ca12dd7..496bca8764 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -611,7 +611,7 @@ public class MatrixObject extends
CacheableData<MatrixBlock> {
MetaDataFormat iimd = (MetaDataFormat) _metaData;
FileFormat fmt = (ofmt != null ? FileFormat.safeValueOf(ofmt) :
iimd.getFileFormat());
MatrixWriter writer =
MatrixWriterFactory.createMatrixWriter(fmt, rep, fprop);
- return writer.writeMatrixFromStream(fname, getStreamHandle(),
+ return writer.writeMatrixFromStream(fname,
getStreamHandle().toLocalTaskQueue(),
getNumRows(), getNumColumns(),
ConfigurationManager.getBlocksize());
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
index 350fc8de3b..783981e0f1 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
@@ -43,8 +43,8 @@ public class LocalTaskQueue<T>
public static final int MAX_SIZE = 100000; //main memory
constraint
public static final Object NO_MORE_TASKS = null; //object to signal
NO_MORE_TASKS
- private LinkedList<T> _data = null;
- private boolean _closedInput = false;
+ protected LinkedList<T> _data = null;
+ protected boolean _closedInput = false;
private DMLRuntimeException _failure = null;
private static final Log LOG =
LogFactory.getLog(LocalTaskQueue.class.getName());
@@ -60,21 +60,19 @@ public class LocalTaskQueue<T>
* @param t task
* @throws InterruptedException if InterruptedException occurs
*/
- public synchronized void enqueueTask( T t )
+ public synchronized void enqueueTask( T t )
throws InterruptedException
{
- while( _data.size() + 1 > MAX_SIZE && _failure == null )
- {
+ while(_data.size() + 1 > MAX_SIZE && _failure == null) {
LOG.warn("MAX_SIZE of task queue reached.");
wait(); //max constraint reached, wait for read
}
- if ( _failure != null )
+ if(_failure != null)
throw _failure;
-
- _data.addLast( t );
-
- notify(); //notify waiting readers
+
+ _data.addLast(t);
+ notify();
}
/**
@@ -97,14 +95,14 @@ public class LocalTaskQueue<T>
if ( _failure != null )
throw _failure;
-
+
T t = _data.removeFirst();
notify(); // notify waiting writers
return t;
}
-
+
/**
* Synchronized (logical) insert of a NO_MORE_TASKS symbol at the end
of the FIFO queue in order to
* mark that no more tasks will be inserted into the queue.
@@ -112,7 +110,7 @@ public class LocalTaskQueue<T>
public synchronized void closeInput()
{
_closedInput = true;
- notifyAll(); //notify all waiting readers
+ notifyAll();
}
public synchronized boolean isProcessed() {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
index c87b3c99cf..2a53c5400a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
@@ -76,7 +76,7 @@ public class AggregateUnaryOOCInstruction extends
ComputationOOCInstruction {
//setup operators and input queue
AggregateUnaryOperator aggun = (AggregateUnaryOperator)
getOperator();
MatrixObject min = ec.getMatrixObject(input1);
- LocalTaskQueue<IndexedMatrixValue> q = min.getStreamHandle();
+ OOCStream<IndexedMatrixValue> q = min.getStreamHandle();
int blen = ConfigurationManager.getBlocksize();
if (aggun.isRowAggregate() || aggun.isColAggregate()) {
@@ -86,13 +86,13 @@ public class AggregateUnaryOOCInstruction extends
ComputationOOCInstruction {
OOCMatrixBlockTracker aggTracker = new
OOCMatrixBlockTracker(emitThreshold);
HashMap<Long, MatrixBlock> corrs = new HashMap<>(); //
correction blocks
- LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
+ OOCStream<IndexedMatrixValue> qOut =
createWritableStream();
ec.getMatrixObject(output).setStreamHandle(qOut);
submitOOCTask(() -> {
IndexedMatrixValue tmp = null;
try {
- while((tmp = q.dequeueTask())
!= LocalTaskQueue.NO_MORE_TASKS) {
+ while((tmp = q.dequeue()) !=
LocalTaskQueue.NO_MORE_TASKS) {
long idx =
aggun.isRowAggregate() ?
tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex();
MatrixBlock ret =
aggTracker.get(idx);
@@ -139,7 +139,7 @@ public class AggregateUnaryOOCInstruction extends
ComputationOOCInstruction {
new
MatrixIndexes(1, tmp.getIndexes().getColumnIndex());
IndexedMatrixValue
tmpOut = new IndexedMatrixValue(midx, ret);
-
qOut.enqueueTask(tmpOut);
+ qOut.enqueue(tmpOut);
// drop intermediate
states
aggTracker.remove(idx);
corrs.remove(idx);
@@ -159,7 +159,7 @@ public class AggregateUnaryOOCInstruction extends
ComputationOOCInstruction {
MatrixBlock ret = new MatrixBlock(1,1+extra,false);
MatrixBlock corr = new MatrixBlock(1,1+extra,false);
try {
- while((tmp = q.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
+ while((tmp = q.dequeue()) !=
LocalTaskQueue.NO_MORE_TASKS) {
//block aggregation
MatrixBlock ltmp = (MatrixBlock)
((MatrixBlock) tmp.getValue())
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes());
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
index 1dfc99be81..148592b6a9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
@@ -19,16 +19,14 @@
package org.apache.sysds.runtime.instructions.ooc;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
@@ -54,33 +52,46 @@ public class BinaryOOCInstruction extends
ComputationOOCInstruction {
@Override
public void processInstruction( ExecutionContext ec ) {
- //TODO support all types, currently only binary matrix-scalar
-
+ if (input1.isMatrix() && input2.isMatrix())
+ processMatrixMatrixInstruction(ec);
+ else
+ processScalarMatrixInstruction(ec);
+ }
+
+ protected void processMatrixMatrixInstruction(ExecutionContext ec) {
+ MatrixObject m1 = ec.getMatrixObject(input1);
+ MatrixObject m2 = ec.getMatrixObject(input2);
+
+ OOCStream<IndexedMatrixValue> qIn1 = m1.getStreamHandle();
+ OOCStream<IndexedMatrixValue> qIn2 = m2.getStreamHandle();
+ OOCStream<IndexedMatrixValue> qOut = new
SubscribableTaskQueue<>();
+ ec.getMatrixObject(output).setStreamHandle(qOut);
+
+ joinOOC(qIn1, qIn2, qOut, (tmp1, tmp2) -> {
+ IndexedMatrixValue tmpOut = new IndexedMatrixValue();
+ tmpOut.set(tmp1.getIndexes(),
+
tmp1.getValue().binaryOperations((BinaryOperator)_optr, tmp2.getValue(),
tmpOut.getValue()));
+ return tmpOut;
+ }, IndexedMatrixValue::getIndexes);
+ }
+
+ protected void processScalarMatrixInstruction(ExecutionContext ec) {
//get operator and scalar
- CPOperand scalar = ( input1.getDataType() == DataType.MATRIX )
? input2 : input1;
+ CPOperand scalar = input1.isMatrix() ? input2 : input1;
ScalarObject constant = ec.getScalarInput(scalar);
ScalarOperator sc_op =
((ScalarOperator)_optr).setConstant(constant.getDoubleValue());
-
+
//create thread and process binary operation
- MatrixObject min = ec.getMatrixObject(input1);
- LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
- LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
+ MatrixObject min = ec.getMatrixObject(input1.isMatrix() ?
input1 : input2);
+ OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
+ OOCStream<IndexedMatrixValue> qOut = createWritableStream();
ec.getMatrixObject(output).setStreamHandle(qOut);
-
- submitOOCTask(() -> {
- IndexedMatrixValue tmp = null;
- try {
- while((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
- IndexedMatrixValue tmpOut = new
IndexedMatrixValue();
- tmpOut.set(tmp.getIndexes(),
-
tmp.getValue().scalarOperations(sc_op, new MatrixBlock()));
- qOut.enqueueTask(tmpOut);
- }
- qOut.closeInput();
- }
- catch(Exception ex) {
- throw new DMLRuntimeException(ex);
- }
- }, qIn, qOut);
+
+ mapOOC(qIn, qOut, tmp -> {
+ IndexedMatrixValue tmpOut = new IndexedMatrixValue();
+ tmpOut.set(tmp.getIndexes(),
+ tmp.getValue().scalarOperations(sc_op, new
MatrixBlock()));
+ return tmpOut;
+ });
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
new file mode 100644
index 0000000000..1a54030280
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A caching wrapper around a source OOCStream that consumes the stream once, caches
+ * its blocks, and lets further operators replay them via independent PlaybackStreams.
+ *
+ */
+public class CachingStream implements OOCStreamable<IndexedMatrixValue> {
+
+ public static final IDSequence _streamSeq = new IDSequence();
+
+ // original live stream
+ private final OOCStream<IndexedMatrixValue> _source;
+
+ // stream identifier
+ private final long _streamId;
+
+ // block counter
+ private int _numBlocks = 0;
+
+ private Runnable[] _subscribers;
+
+ // state flags
+ private boolean _cacheInProgress = true; // caching in progress, in the
first pass.
+ private Map<MatrixIndexes, Integer> _index;
+
+ public CachingStream(OOCStream<IndexedMatrixValue> source) {
+ this(source, _streamSeq.getNextID());
+ }
+
+ public CachingStream(OOCStream<IndexedMatrixValue> source, long
streamId) {
+ _source = source;
+ _streamId = streamId;
+ source.setSubscriber(() -> {
+ try {
+ boolean closed = fetchFromStream();
+ Runnable[] mSubscribers = _subscribers;
+
+ if(mSubscribers != null) {
+ for(Runnable mSubscriber : mSubscribers)
+ mSubscriber.run();
+
+ if (closed) {
+ synchronized (this) {
+ _subscribers = null;
+ }
+ }
+ }
+ } catch (InterruptedException e) {
+ throw new DMLRuntimeException(e);
+ }
+ });
+ }
+
+ private boolean fetchFromStream() throws InterruptedException {
+ synchronized (this) {
+ if(!_cacheInProgress)
+ throw new DMLRuntimeException("Stream is
closed");
+ }
+
+ IndexedMatrixValue task = _source.dequeue();
+
+ synchronized (this) {
+ if(task != LocalTaskQueue.NO_MORE_TASKS) {
+ OOCEvictionManager.put(_streamId, _numBlocks,
task);
+ if (_index != null)
+ _index.put(task.getIndexes(),
_numBlocks);
+ _numBlocks++;
+ notifyAll();
+ return false;
+ }
+ else {
+ _cacheInProgress = false; // caching is complete
+ notifyAll();
+ return true;
+ }
+ }
+ }
+
+ public synchronized IndexedMatrixValue get(int idx) throws
InterruptedException {
+ while (true) {
+ if (idx < _numBlocks)
+ return OOCEvictionManager.get(_streamId, idx);
+ else if (!_cacheInProgress)
+ return
(IndexedMatrixValue)LocalTaskQueue.NO_MORE_TASKS;
+
+ wait();
+ }
+ }
+
+ public synchronized IndexedMatrixValue findCached(MatrixIndexes idx) {
+ return OOCEvictionManager.get(_streamId, _index.get(idx));
+ }
+
+ public synchronized void activateIndexing() {
+ if (_index == null)
+ _index = new HashMap<>();
+ }
+
+ @Override
+ public OOCStream<IndexedMatrixValue> getReadStream() {
+ return new PlaybackStream(this);
+ }
+
+ @Override
+ public OOCStream<IndexedMatrixValue> getWriteStream() {
+ return _source.getWriteStream();
+ }
+
+ @Override
+ public boolean isProcessed() {
+ return false;
+ }
+
+ @Override
+ public void setSubscriber(Runnable subscriber) {
+ int mNumBlocks;
+ synchronized (this) {
+ mNumBlocks = _numBlocks;
+ if (_cacheInProgress) {
+ int newLen = _subscribers == null ? 1 :
_subscribers.length + 1;
+ Runnable[] newSubscribers = new
Runnable[newLen];
+
+ if(newLen > 1)
+ System.arraycopy(_subscribers, 0,
newSubscribers, 0, newLen - 1);
+
+ newSubscribers[newLen - 1] = subscriber;
+ _subscribers = newSubscribers;
+ }
+ }
+
+ for (int i = 0; i < mNumBlocks; i++)
+ subscriber.run();
+
+ if (!_cacheInProgress)
+ subscriber.run(); // To fetch the NO_MORE_TASK element
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java
index 9c122662c2..7b3346ab6d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java
@@ -30,17 +30,9 @@ import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-
public class CentralMomentOOCInstruction extends AggregateUnaryOOCInstruction {
private CentralMomentOOCInstruction(CMOperator cm, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out,
@@ -70,7 +62,7 @@ public class CentralMomentOOCInstruction extends
AggregateUnaryOOCInstruction {
*/
MatrixObject matObj = ec.getMatrixObject(input1.getName());
- LocalTaskQueue<IndexedMatrixValue> qIn =
matObj.getStreamHandle();
+ OOCStream<IndexedMatrixValue> qIn = matObj.getStreamHandle();
CPOperand scalarInput = (input3 == null ? input2 : input3);
ScalarObject order = ec.getScalarInput(scalarInput);
@@ -81,20 +73,10 @@ public class CentralMomentOOCInstruction extends
AggregateUnaryOOCInstruction {
CMOperator finalCm_op = cm_op;
- List<CM_COV_Object> cmObjs = new ArrayList<>();
+ OOCStream<CM_COV_Object> cmObjs = createWritableStream();
if(input3 == null) {
- try {
- IndexedMatrixValue tmp;
-
- while((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
- // We only handle MatrixBlock, other
types of MatrixValue will fail here
- cmObjs.add(((MatrixBlock)
tmp.getValue()).cmOperations(cm_op));
- }
- }
- catch(Exception ex) {
- throw new DMLRuntimeException(ex);
- }
+ mapOOC(qIn, cmObjs, tmp -> ((MatrixBlock)
tmp.getValue()).cmOperations(new CMOperator(finalCm_op))); // Need to copy
CMOperator as its ValueFunction is stateful
}
else {
// Here we use a hash join approach
@@ -107,59 +89,23 @@ public class CentralMomentOOCInstruction extends
AggregateUnaryOOCInstruction {
if (dc.getBlocksize() != dcW.getBlocksize())
throw new DMLRuntimeException("Different block
sizes are not yet supported");
- LocalTaskQueue<IndexedMatrixValue> wIn =
wtObj.getStreamHandle();
-
- try {
- IndexedMatrixValue tmp = qIn.dequeueTask();
- IndexedMatrixValue tmpW = wIn.dequeueTask();
- Map<MatrixIndexes, MatrixValue> left = new
HashMap<>();
- Map<MatrixIndexes, MatrixValue> right = new
HashMap<>();
-
- boolean cont = tmp !=
LocalTaskQueue.NO_MORE_TASKS || tmpW != LocalTaskQueue.NO_MORE_TASKS;
-
- while(cont) {
- cont = false;
-
- if(tmp != LocalTaskQueue.NO_MORE_TASKS)
{
- MatrixValue weights =
right.remove(tmp.getIndexes());
-
- if(weights != null)
-
cmObjs.add(((MatrixBlock) tmp.getValue()).cmOperations(cm_op, (MatrixBlock)
weights));
- else
-
left.put(tmp.getIndexes(), tmp.getValue());
-
- tmp = qIn.dequeueTask();
- cont = tmp !=
LocalTaskQueue.NO_MORE_TASKS;
- }
+ OOCStream<IndexedMatrixValue> wIn =
wtObj.getStreamHandle();
- if(tmpW !=
LocalTaskQueue.NO_MORE_TASKS) {
- MatrixValue q =
left.remove(tmpW.getIndexes());
-
- if(q != null)
-
cmObjs.add(((MatrixBlock) q).cmOperations(cm_op, (MatrixBlock)
tmpW.getValue()));
- else
-
right.put(tmpW.getIndexes(), tmpW.getValue());
-
- tmpW = wIn.dequeueTask();
- cont |= tmpW !=
LocalTaskQueue.NO_MORE_TASKS;
- }
- }
-
- if (!left.isEmpty() || !right.isEmpty())
- throw new
DMLRuntimeException("Unmatched blocks: values=" + left.size() + ", weights=" +
right.size());
- }
- catch(Exception ex) {
- throw new DMLRuntimeException(ex);
- }
+ joinOOC(qIn, wIn, cmObjs,
+ (tmp, weights) ->
+ ((MatrixBlock)
tmp.getValue()).cmOperations(new CMOperator(finalCm_op), (MatrixBlock)
weights.getValue()),
+ IndexedMatrixValue::getIndexes);
}
- Optional<CM_COV_Object> res = cmObjs.stream()
- .reduce((arg0, arg1) -> (CM_COV_Object)
finalCm_op.fn.execute(arg0, arg1));
-
try {
- ec.setScalarOutput(output_name, new
DoubleObject(res.get().getRequiredResult(finalCm_op)));
- }
- catch(Exception ex) {
+ CM_COV_Object agg = cmObjs.dequeue();
+ CM_COV_Object next;
+
+ while ((next = cmObjs.dequeue()) !=
LocalTaskQueue.NO_MORE_TASKS)
+ agg = (CM_COV_Object)
finalCm_op.fn.execute(agg, next);
+
+ ec.setScalarOutput(output_name, new
DoubleObject(agg.getRequiredResult(finalCm_op)));
+ } catch (Exception ex) {
throw new DMLRuntimeException(ex);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java
index c4c668ab6b..01fd348d10 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CtableOOCInstruction.java
@@ -79,7 +79,7 @@ public class CtableOOCInstruction extends
ComputationOOCInstruction {
public void processInstruction( ExecutionContext ec ) {
MatrixObject in1 = ec.getMatrixObject(input1); // stream
- LocalTaskQueue<IndexedMatrixValue> qIn1 = in1.getStreamHandle();
+ OOCStream<IndexedMatrixValue> qIn1 = in1.getStreamHandle();
IndexedMatrixValue tmp1 = null;
long outputDim1 = ec.getScalarInput(_outDim1).getLongValue();
@@ -90,7 +90,7 @@ public class CtableOOCInstruction extends
ComputationOOCInstruction {
Ctable.OperationTypes ctableOp = findCtableOperation();
MatrixObject in2 = null, in3 = null;
- LocalTaskQueue<IndexedMatrixValue> qIn2 = null, qIn3 = null;
+ OOCStream<IndexedMatrixValue> qIn2 = null, qIn3 = null;
double cst2 = 0, cst3 = 0;
// init vars based on ctableOp
@@ -121,7 +121,7 @@ public class CtableOOCInstruction extends
ComputationOOCInstruction {
}
try {
- while((tmp1 = qIn1.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
+ while((tmp1 = qIn1.dequeue()) !=
LocalTaskQueue.NO_MORE_TASKS) {
MatrixBlock block1 = (MatrixBlock)
tmp1.getValue();
long r = tmp1.getIndexes().getRowIndex();
@@ -172,13 +172,13 @@ public class CtableOOCInstruction extends
ComputationOOCInstruction {
}
private MatrixBlock getOrDequeueBlock(long key, long cols,
HashMap<Long, MatrixBlock> blocks,
- LocalTaskQueue<IndexedMatrixValue> queue) throws
InterruptedException
+ OOCStream<IndexedMatrixValue> queue) throws InterruptedException
{
MatrixBlock block = blocks.get(key);
if (block == null) {
IndexedMatrixValue tmp;
// corresponding block still in queue, dequeue until
found
- while ((tmp = queue.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
+ while ((tmp = queue.dequeue()) !=
LocalTaskQueue.NO_MORE_TASKS) {
block = (MatrixBlock) tmp.getValue();
long r = tmp.getIndexes().getRowIndex();
long c = tmp.getIndexes().getColumnIndex();
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
index aa215e83e9..38586428e1 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
@@ -83,15 +83,15 @@ public class MatrixVectorBinaryOOCInstruction extends
ComputationOOCInstruction
long emitThreshold =
min.getDataCharacteristics().getNumColBlocks();
OOCMatrixBlockTracker aggTracker = new
OOCMatrixBlockTracker(emitThreshold);
- LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
- LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
+ OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
+ OOCStream<IndexedMatrixValue> qOut = createWritableStream();
BinaryOperator plus =
InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
ec.getMatrixObject(output).setStreamHandle(qOut);
submitOOCTask(() -> {
IndexedMatrixValue tmp = null;
try {
- while((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
+ while((tmp = qIn.dequeue()) !=
LocalTaskQueue.NO_MORE_TASKS) {
MatrixBlock matrixBlock =
(MatrixBlock) tmp.getValue();
long rowIndex =
tmp.getIndexes().getRowIndex();
long colIndex =
tmp.getIndexes().getColumnIndex();
@@ -103,7 +103,7 @@ public class MatrixVectorBinaryOOCInstruction extends
ComputationOOCInstruction
// for single column block, no
aggregation neeeded
if(emitThreshold == 1) {
- qOut.enqueueTask(new
IndexedMatrixValue(tmp.getIndexes(), partialResult));
+ qOut.enqueue(new
IndexedMatrixValue(tmp.getIndexes(), partialResult));
}
else {
// aggregation
@@ -116,7 +116,7 @@ public class MatrixVectorBinaryOOCInstruction extends
ComputationOOCInstruction
if
(aggTracker.putAndIncrementCount(rowIndex, currAgg)){
//
early block output: emit aggregated block
MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L);
-
qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg));
+
qOut.enqueue(new IndexedMatrixValue(idx, currAgg));
aggTracker.remove(rowIndex);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index 0d15949289..1fdd5cd965 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -19,29 +19,48 @@
package org.apache.sysds.runtime.instructions.ooc;
+import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.apache.sysds.runtime.util.OOCJoin;
+import java.util.ArrayList;
import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.stream.Stream;
public abstract class OOCInstruction extends Instruction {
protected static final Log LOG =
LogFactory.getLog(OOCInstruction.class.getName());
+ // Source of ids for the stream pipelines set up by submitOOCTasks; the id
+ // is currently only referenced by commented-out debug output there.
+ private static final AtomicInteger nextStreamId = new AtomicInteger(0);
public enum OOCType {
- Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary,
MAPMM, MMTSJ, Reorg, CM, Ctable
+ Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary,
MAPMM, MMTSJ, Reorg, CM, Ctable, MatrixIndexing
}
protected final OOCInstruction.OOCType _ooctype;
protected final boolean _requiresLabelUpdate;
+ // Streams consumed (_inQueues) and produced (_outQueues) by this
+ // instruction; when one of its tasks fails, the failure is forwarded to
+ // these streams. Lazily created by addInStream/addOutStream.
+ protected Set<OOCStream<?>> _inQueues;
+ protected Set<OOCStream<?>> _outQueues;
+ // Set once the first task failure has been handled, so that failures are
+ // not propagated in cycles between streams and their subscribers.
+ private boolean _failed;
protected OOCInstruction(OOCInstruction.OOCType type, String opcode,
String istr) {
this(type, null, opcode, istr);
@@ -54,6 +73,7 @@ public abstract class OOCInstruction extends Instruction {
instOpcode = opcode;
_requiresLabelUpdate = super.requiresLabelUpdate();
+ _failed = false;
}
@Override
@@ -90,10 +110,203 @@ public abstract class OOCInstruction extends Instruction {
ec.maintainLineageDebuggerInfo(this);
}
- protected void submitOOCTask(Runnable r, LocalTaskQueue<?>... queues) {
+	/**
+	 * Registers streams consumed by this instruction, so that failures of its
+	 * tasks are forwarded to them.
+	 *
+	 * @param queue input streams to register
+	 */
+	protected void addInStream(OOCStream<?>... queue) {
+		if (_inQueues == null)
+			_inQueues = new HashSet<>();
+		_inQueues.addAll(List.of(queue));
+	}
+
+	/**
+	 * Registers streams produced by this instruction, so that failures of its
+	 * tasks are forwarded to them.
+	 *
+	 * @param queue output streams to register
+	 */
+	protected void addOutStream(OOCStream<?>... queue) {
+		// Same registration logic as addInStream, but for the output set
+		if (_outQueues == null)
+			_outQueues = new HashSet<>();
+		_outQueues.addAll(List.of(queue));
+	}
+
+	/**
+	 * Creates a new stream that producers can enqueue into and consumers can
+	 * subscribe to.
+	 */
+	protected <T> OOCStream<T> createWritableStream() {
+		return new SubscribableTaskQueue<>();
+	}
+
+	/**
+	 * Passes every item of {@code qIn} that satisfies {@code predicate} to
+	 * {@code processor}; other items are dropped. {@code finalizer} runs after
+	 * the stream has been processed. Unlike {@link #mapOOC}, no stream is
+	 * registered automatically. All input and output streams must already be
+	 * registered via {@link #addInStream} / {@link #addOutStream}, so that
+	 * failures can be propagated to them.
+	 *
+	 * @throws NotImplementedException if no input or no output stream is registered
+	 */
+	protected <T, R> CompletableFuture<Void> filterOOC(OOCStream<T> qIn, Consumer<T> processor, Function<T, Boolean> predicate, Runnable finalizer) {
+		if (_inQueues == null || _outQueues == null)
+			throw new NotImplementedException("filterOOC requires manual specification of all input and output streams for error propagation");
+
+		return submitOOCTasks(qIn, processor, finalizer, predicate);
+	}
+
+	/**
+	 * Applies {@code mapper} to every item of {@code qIn} and enqueues the
+	 * results into {@code qOut}. {@code qOut} is closed once {@code qIn} has
+	 * been fully processed. Items are processed in separate pool tasks, so the
+	 * output order may differ from the input order. Both streams are
+	 * registered for failure propagation. Exceptions thrown by the mapper are
+	 * rethrown as {@link DMLRuntimeException}.
+	 *
+	 * @return a future that completes once all items have been processed
+	 */
+	protected <T, R> CompletableFuture<Void> mapOOC(OOCStream<T> qIn, OOCStream<R> qOut, Function<T, R> mapper) {
+		addInStream(qIn);
+		addOutStream(qOut);
+
+		return submitOOCTasks(qIn, tmp -> {
+			try {
+				R r = mapper.apply(tmp);
+				qOut.enqueue(r);
+			} catch (Exception e) {
+				throw e instanceof DMLRuntimeException ? (DMLRuntimeException) e : new DMLRuntimeException(e);
+			}
+		}, qOut::closeInput);
+	}
+
+	/**
+	 * Joins two streams using the same key function on both sides; see the
+	 * overload with separate key functions.
+	 */
+	protected <T, R, P> CompletableFuture<Void> joinOOC(OOCStream<T> qIn1, OOCStream<T> qIn2, OOCStream<R> qOut, BiFunction<T, T, R> mapper, Function<T, P> on) {
+		return joinOOC(qIn1, qIn2, qOut, mapper, on, on);
+	}
+
+	/**
+	 * Joins the items of two streams on their keys. For each match reported by
+	 * {@link OOCJoin}, {@code mapper.apply(left, right)} is enqueued into
+	 * {@code qOut}. Both inputs are wrapped in (or reuse) a
+	 * {@link CachingStream}. The join therefore only keeps keys and block
+	 * indexes, and matched items are fetched back from the caches. Items are
+	 * assumed to be {@link IndexedMatrixValue}s.
+	 *
+	 * @return a future that completes after {@code qOut} was closed, or
+	 *         exceptionally if processing failed
+	 */
+	@SuppressWarnings("unchecked")
+	protected <T, R, P> CompletableFuture<Void> joinOOC(OOCStream<T> qIn1, OOCStream<T> qIn2, OOCStream<R> qOut, BiFunction<T, T, R> mapper, Function<T, P> onLeft, Function<T, P> onRight) {
+		addInStream(qIn1, qIn2);
+		addOutStream(qOut);
+
+		final CompletableFuture<Void> future = new CompletableFuture<>();
+
+		// We need to construct our own stream to properly manage the cached items in the hash join
+		CachingStream leftCache = qIn1.hasStreamCache() ? qIn1.getStreamCache() : new CachingStream((SubscribableTaskQueue<IndexedMatrixValue>)qIn1); // We have to assume this generic type for now
+		CachingStream rightCache = qIn2.hasStreamCache() ? qIn2.getStreamCache() : new CachingStream((SubscribableTaskQueue<IndexedMatrixValue>)qIn2); // We have to assume this generic type for now
+		leftCache.activateIndexing();
+		rightCache.activateIndexing();
+
+		final OOCJoin<P, MatrixIndexes> join = new OOCJoin<>((idx, left, right) -> {
+			T leftObj = (T) leftCache.findCached(left);
+			T rightObj = (T) rightCache.findCached(right);
+			qOut.enqueue(mapper.apply(leftObj, rightObj));
+		});
+
+		submitOOCTasks(List.of(leftCache.getReadStream(), rightCache.getReadStream()), (i, tmp) -> {
+			if (i == 0)
+				join.addLeft(onLeft.apply((T)tmp), ((IndexedMatrixValue) tmp).getIndexes());
+			else
+				join.addRight(onRight.apply((T)tmp), ((IndexedMatrixValue) tmp).getIndexes());
+		}, () -> {
+			join.close();
+			qOut.closeInput();
+			future.complete(null);
+		}).whenComplete((res, ex) -> {
+			// If a task fails, submitOOCTasks does not run the finalizer above:
+			// the failing task never decrements its task counters. Forward the
+			// failure explicitly, otherwise callers waiting on the returned
+			// future would block forever.
+			if (ex != null)
+				future.completeExceptionally(ex);
+		});
+
+		return future;
+	}
+
+	/**
+	 * Variant of submitOOCTasks without a predicate. It creates one fresh
+	 * future per input stream.
+	 */
+	protected <T> CompletableFuture<Void> submitOOCTasks(final List<OOCStream<T>> queues, BiConsumer<Integer, T> consumer, Runnable finalizer) {
+		final int numStreams = queues.size();
+		final List<CompletableFuture<Void>> perStreamFutures = new ArrayList<>(numStreams);
+		while (perStreamFutures.size() < numStreams)
+			perStreamFutures.add(new CompletableFuture<>());
+
+		return submitOOCTasks(queues, consumer, finalizer, perStreamFutures, null);
+	}
+
+	/**
+	 * Subscribes to all given streams and processes their items asynchronously
+	 * on the common thread pool.
+	 * <p>
+	 * Each item dequeued from stream {@code k} is passed to
+	 * {@code consumer.accept(k, item)} in a separate pool task. A {@code null}
+	 * item marks the end of that stream. {@code futures.get(k)} is completed
+	 * once the end of stream {@code k} was seen and none of its tasks is
+	 * pending. The finalizer runs after all futures are done and no task is
+	 * pending anymore.
+	 * <p>
+	 * NOTE(review): a task whose consumer throws never decrements the task
+	 * counters, so the finalizer is not run after such a failure. The failure
+	 * is reported via the registered streams and the task's future instead.
+	 *
+	 * @param queues    input streams, also registered as inputs of this instruction
+	 * @param consumer  receives the stream index and each non-end item
+	 * @param finalizer executed once processing of all streams has finished
+	 * @param futures   one future per stream, completed when that stream is done
+	 * @param predicate optional filter (may be {@code null}); items for which it
+	 *                  returns {@code false} are dropped before any task is created
+	 * @return a future that completes when all per-stream futures have completed
+	 */
+	protected <T> CompletableFuture<Void> submitOOCTasks(final List<OOCStream<T>> queues, BiConsumer<Integer, T> consumer, Runnable finalizer, List<CompletableFuture<Void>> futures, BiFunction<Integer, T, Boolean> predicate) {
+		addInStream(queues.toArray(OOCStream[]::new));
+		ExecutorService pool = CommonThreadPool.get();
+
+		// Per-stream bookkeeping: number of pending tasks and whether the end
+		// of the stream has been reached
+		final List<AtomicInteger> activeTaskCtrs = new ArrayList<>(queues.size());
+		final List<AtomicBoolean> streamsClosed = new ArrayList<>(queues.size());
+
+		for (int i = 0; i < queues.size(); i++) {
+			activeTaskCtrs.add(new AtomicInteger(0));
+			streamsClosed.add(new AtomicBoolean(false));
+		}
+
+		final AtomicInteger globalTaskCtr = new AtomicInteger(0);
+		final CompletableFuture<Void> globalFuture = CompletableFuture.allOf(futures.toArray(CompletableFuture[]::new));
+		final Runnable oocFinalizer = oocTask(finalizer, null, getRegisteredStreams());
+		// Guards the transitions between task counters and future completion
+		final Object globalLock = new Object();
+
+		int i = 0;
+		@SuppressWarnings("unused")
+		final int streamId = nextStreamId.getAndIncrement();
+		//System.out.println("New stream: (id " + streamId + ", size " + queues.size() + ", initiator '" + this.getClass().getSimpleName() + "')");
+
+		for (OOCStream<T> queue : queues) {
+			final int k = i;
+			final AtomicInteger localTaskCtr = activeTaskCtrs.get(k);
+			final AtomicBoolean localStreamClosed = streamsClosed.get(k);
+			final CompletableFuture<Void> localFuture = futures.get(k);
+
+			//System.out.println("Substream (k " + k + ", id " + streamId + ", type '" + queue.getClass().getSimpleName() + "', stream_id " + queue.hashCode() + ")");
+			// Notified once per available element: dequeue it here and
+			// process it in a separate pool task
+			queue.setSubscriber(oocTask(() -> {
+				final T item = queue.dequeue();
+
+				// Drop items rejected by the predicate (never the end marker null)
+				if (predicate != null && item != null && !predicate.apply(k, item))
+					return;
+
+				synchronized (globalLock) {
+					// Stream already finished, e.g., cancelled after a failure
+					if (localFuture.isDone())
+						return;
+
+					globalTaskCtr.incrementAndGet();
+				}
+
+				localTaskCtr.incrementAndGet();
+
+				pool.submit(oocTask(() -> {
+					if(item != null) {
+						//System.out.println("Accept" + ((IndexedMatrixValue)item).getIndexes() + " (k " + k + ", id " + streamId + ")");
+						consumer.accept(k, item);
+					}
+					else {
+						//System.out.println("Close substream (k " + k + ", id " + streamId + ")");
+						localStreamClosed.set(true);
+					}
+
+					boolean runFinalizer = false;
+
+					synchronized (globalLock) {
+						int localTasks = localTaskCtr.decrementAndGet();
+						boolean finalizeStream = localTasks == 0 && localStreamClosed.get();
+
+						// Pending tasks of all streams, excluding this one
+						int globalTasks = globalTaskCtr.get() - 1;
+
+						if (finalizeStream || (globalFuture.isDone() && localTasks == 0)) {
+							localFuture.complete(null);
+
+							// Completing the last stream may complete globalFuture
+							if (globalFuture.isDone() && globalTasks == 0)
+								runFinalizer = true;
+						}
+
+						globalTaskCtr.decrementAndGet();
+					}
+
+					// Run the finalizer outside of the lock
+					if (runFinalizer)
+						oocFinalizer.run();
+				}, localFuture, getRegisteredStreams()));
+			}, null, getRegisteredStreams()));
+
+			i++;
+		}
+
+		pool.shutdown();
+
+		// Cancel all per-stream futures on failure, and run the finalizer if no
+		// task is pending when globalFuture completes
+		globalFuture.whenComplete((res, e) -> {
+			if (globalFuture.isCancelled() || globalFuture.isCompletedExceptionally())
+				futures.forEach(f -> f.cancel(true));
+
+			boolean runFinalizer;
+
+			synchronized (globalLock) {
+				runFinalizer = globalTaskCtr.get() == 0;
+			}
+
+			if (runFinalizer)
+				oocFinalizer.run();
+
+			//System.out.println("Shutdown (id " + streamId + ")");
+		});
+		return globalFuture;
+	}
+
+	/**
+	 * Returns the registered output and input streams of this instruction,
+	 * outputs first. A set that was never created (e.g., no output stream was
+	 * registered) is treated as empty instead of causing a NullPointerException.
+	 */
+	private OOCStream<?>[] getRegisteredStreams() {
+		Stream<OOCStream<?>> outs = _outQueues == null ? Stream.empty() : _outQueues.stream();
+		Stream<OOCStream<?>> ins = _inQueues == null ? Stream.empty() : _inQueues.stream();
+		return Stream.concat(outs, ins).toArray(OOCStream[]::new);
+	}
+
+	/**
+	 * Processes the items of a single stream; see the list-based variant.
+	 */
+	protected <T> CompletableFuture<Void> submitOOCTasks(OOCStream<T> queue, Consumer<T> consumer, Runnable finalizer) {
+		final BiConsumer<Integer, T> ignoreIndex = (idx, item) -> consumer.accept(item);
+		return submitOOCTasks(List.of(queue), ignoreIndex, finalizer);
+	}
+
+	/**
+	 * Processes the items of a single stream that satisfy {@code predicate};
+	 * see the list-based variant.
+	 */
+	protected <T> CompletableFuture<Void> submitOOCTasks(OOCStream<T> queue, Consumer<T> consumer, Runnable finalizer, Function<T, Boolean> predicate) {
+		final BiConsumer<Integer, T> ignoreIndex = (idx, item) -> consumer.accept(item);
+		final BiFunction<Integer, T, Boolean> indexedPredicate = (idx, item) -> predicate.apply(item);
+		final List<CompletableFuture<Void>> singleFuture = List.of(new CompletableFuture<Void>());
+		return submitOOCTasks(List.of(queue), ignoreIndex, finalizer, singleFuture, indexedPredicate);
+	}
+
+ /**
+ * Runs {@code r} asynchronously on the common thread pool.
+ *
+ * @param r      task to execute
+ * @param queues streams that are notified if {@code r} fails
+ * @return a future that completes normally once {@code r} has finished;
+ *         on failure it is handled as implemented in oocTask
+ */
+ protected CompletableFuture<Void> submitOOCTask(Runnable r,
OOCStream<?>... queues) {
ExecutorService pool = CommonThreadPool.get();
+ final CompletableFuture<Void> future = new
CompletableFuture<>();
try {
- pool.submit(oocTask(r, queues));
+ pool.submit(oocTask(() ->
{r.run();future.complete(null);}, future, queues));
}
catch (Exception ex) {
throw new DMLRuntimeException(ex);
@@ -101,9 +314,11 @@ public abstract class OOCInstruction extends Instruction {
finally {
pool.shutdown();
}
+
+ return future;
}
- private Runnable oocTask(Runnable r, LocalTaskQueue<?>... queues) {
+ /**
+ * Wraps {@code r} for asynchronous execution. An exception thrown by
+ * {@code r} is converted into a {@link DMLRuntimeException}, reported to
+ * {@code queues} and to {@code future} (if not null) as implemented in the
+ * catch block, and rethrown.
+ */
+ private Runnable oocTask(Runnable r, CompletableFuture<Void> future,
OOCStream<?>... queues) {
return () -> {
try {
r.run();
@@ -111,9 +326,16 @@ public abstract class OOCInstruction extends Instruction {
catch (Exception ex) {
DMLRuntimeException re = ex instanceof
DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
- for (LocalTaskQueue<?> q : queues) {
- q.propagateFailure(re);
- }
+				// Only the first failure of this instruction is propagated to
+				// the streams, to avoid infinite cycles: propagation wakes up
+				// subscribers, whose tasks may then fail and end up here again.
+				// Check-and-set under a lock, as tasks fail on several threads.
+				boolean firstFailure;
+				synchronized (this) {
+					firstFailure = !_failed;
+					_failed = true;
+				}
+
+				if (firstFailure) {
+					for (OOCStream<?> q : queues) {
+						try {
+							q.propagateFailure(re);
+						}
+						catch (Exception secondary) {
+							// A subscriber reacting to the propagated failure may
+							// rethrow synchronously. Record it and keep notifying
+							// the remaining streams instead of aborting the loop.
+							if (secondary != re)
+								re.addSuppressed(secondary);
+						}
+					}
+				}
+
+				// Always fail the associated future, also for secondary
+				// failures, so that nobody waits on it forever
+				if (future != null)
+					future.completeExceptionally(re);
// Rethrow to ensure proper future handling
throw re;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
new file mode 100644
index 0000000000..1a12cb138b
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+
+/**
+ * A stream of items exchanged between out-of-core instructions. Items are
+ * pushed via enqueue() and pulled via dequeue(); consumers may additionally
+ * register a subscriber callback (see OOCStreamable#setSubscriber).
+ */
+public interface OOCStream<T> extends OOCStreamable<T> {
+ /** Appends an item (write side); read-only streams such as PlaybackStream reject it. */
+ void enqueue(T t);
+
+ /** Returns the next item (read side); queue-backed streams return LocalTaskQueue.NO_MORE_TASKS at the end. */
+ T dequeue();
+
+ /** Signals that no further items will be enqueued. */
+ void closeInput();
+
+ /** Returns a LocalTaskQueue view of this stream (the stream itself for SubscribableTaskQueue). */
+ LocalTaskQueue<T> toLocalTaskQueue();
+
+ /** Forwards a failure to the consumers of this stream (via LocalTaskQueue.propagateFailure for queue-backed streams). */
+ void propagateFailure(DMLRuntimeException re);
+
+ /** Returns whether this stream is backed by a CachingStream. */
+ boolean hasStreamCache();
+
+ /** Returns the backing CachingStream, or null if hasStreamCache() is false. */
+ CachingStream getStreamCache();
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
new file mode 100644
index 0000000000..bdc4086bdc
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+/**
+ * Access points of a stream: separate read and write views plus a
+ * push-style subscription callback.
+ */
+public interface OOCStreamable<T> {
+ /** Returns the view from which consumers dequeue items. */
+ OOCStream<T> getReadStream();
+
+ /** Returns the view into which producers enqueue items. */
+ OOCStream<T> getWriteStream();
+
+ /** Processing-state flag; queue-backed streams inherit LocalTaskQueue.isProcessed (semantics: TODO confirm). */
+ boolean isProcessed();
+
+ /** Registers a callback run whenever the stream has news for its consumer (SubscribableTaskQueue: per item, on close, on failure). */
+ void setSubscriber(Runnable subscriber);
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java
new file mode 100644
index 0000000000..6edc4ecf27
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/PlaybackStream.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+
+/**
+ * Read-only view that replays the items of a {@link CachingStream} in order,
+ * starting with the first cached item. Each instance keeps its own read
+ * position, so several playback streams can consume the same cache.
+ */
+public class PlaybackStream implements OOCStream<IndexedMatrixValue> {
+	private final CachingStream _streamCache;
+	// Index of the next cached item returned by dequeue()
+	private int _streamIdx;
+
+	public PlaybackStream(CachingStream streamCache) {
+		this._streamCache = streamCache;
+		this._streamIdx = 0;
+	}
+
+	@Override
+	public void enqueue(IndexedMatrixValue t) {
+		throw new DMLRuntimeException("Cannot enqueue to a playback stream");
+	}
+
+	@Override
+	public void closeInput() {
+		throw new DMLRuntimeException("Cannot close a playback stream");
+	}
+
+	/**
+	 * Adapts this stream to a new queue. On every subscriber notification of
+	 * the cache, the next replayed item is forwarded into the queue.
+	 * NOTE(review): the end-of-stream element returned by dequeue() is
+	 * enqueued like a regular item instead of closing the queue; confirm this
+	 * is intended.
+	 */
+	@Override
+	public LocalTaskQueue<IndexedMatrixValue> toLocalTaskQueue() {
+		final SubscribableTaskQueue<IndexedMatrixValue> q = new SubscribableTaskQueue<>();
+		setSubscriber(() -> q.enqueue(dequeue()));
+		return q;
+	}
+
+	/** Returns the cached item at the current read position and advances it. */
+	@Override
+	public synchronized IndexedMatrixValue dequeue() {
+		try {
+			return _streamCache.get(_streamIdx++);
+		} catch (InterruptedException e) {
+			Thread.currentThread().interrupt(); // preserve the interrupt status for callers
+			throw new DMLRuntimeException(e);
+		}
+	}
+
+	@Override
+	public OOCStream<IndexedMatrixValue> getReadStream() {
+		return _streamCache.getReadStream();
+	}
+
+	@Override
+	public OOCStream<IndexedMatrixValue> getWriteStream() {
+		return _streamCache.getWriteStream();
+	}
+
+	@Override
+	public boolean isProcessed() {
+		return false;
+	}
+
+	/** Subscriptions are delegated to the underlying cache. */
+	@Override
+	public void setSubscriber(Runnable subscriber) {
+		_streamCache.setSubscriber(subscriber);
+	}
+
+	/** Failures are forwarded to the write side of the underlying cache. */
+	@Override
+	public void propagateFailure(DMLRuntimeException re) {
+		_streamCache.getWriteStream().propagateFailure(re);
+	}
+
+	@Override
+	public boolean hasStreamCache() {
+		return true;
+	}
+
+	@Override
+	public CachingStream getStreamCache() {
+		return _streamCache;
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
index 3c78879b45..74b15c9fb0 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
@@ -28,7 +28,6 @@ import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
@@ -75,7 +74,7 @@ public class ReblockOOCInstruction extends
ComputationOOCInstruction {
//TODO support other formats than binary
//create queue, spawn thread for asynchronous reading, and
return
- LocalTaskQueue<IndexedMatrixValue> q = new
LocalTaskQueue<IndexedMatrixValue>();
+ // writable stream, filled asynchronously by readBinaryBlock, which closes it at the end
+ OOCStream<IndexedMatrixValue> q = createWritableStream();
submitOOCTask(() -> readBinaryBlock(q, min.getFileName()), q);
MatrixObject mout = ec.getMatrixObject(output);
@@ -83,7 +82,7 @@ public class ReblockOOCInstruction extends
ComputationOOCInstruction {
}
@SuppressWarnings("resource")
- private void readBinaryBlock(LocalTaskQueue<IndexedMatrixValue> q,
String fname) {
+ private void readBinaryBlock(OOCStream<IndexedMatrixValue> q, String
fname) {
try {
//prepare file access
JobConf job = new
JobConf(ConfigurationManager.getCachedJobConf());
@@ -102,7 +101,7 @@ public class ReblockOOCInstruction extends
ComputationOOCInstruction {
MatrixIndexes key = new MatrixIndexes();
MatrixBlock value = new MatrixBlock();
+ // for the SubscribableTaskQueue from createWritableStream, enqueue() also runs its subscriber (if any) in this reader thread
while( reader.next(key, value) )
- q.enqueueTask(new
IndexedMatrixValue(key, new MatrixBlock(value)));
+ q.enqueue(new
IndexedMatrixValue(key, new MatrixBlock(value)));
}
}
q.closeInput();
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java
deleted file mode 100644
index 6179811f7a..0000000000
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.runtime.instructions.ooc;
-
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
-import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
-import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
-
-
-/**
- * A wrapper around LocalTaskQueue to consume the source stream and reset to
- * consume again for other operators.
- * <p>
- * Uses OOCEvictionManager for out-of-core caching.
- *
- */
-public class ResettableStream extends LocalTaskQueue<IndexedMatrixValue> {
-
- // original live stream
- private final LocalTaskQueue<IndexedMatrixValue> _source;
-
- private static final IDSequence _streamSeq = new IDSequence();
- // stream identifier
- private final long _streamId;
-
- // block counter
- private int _numBlocks = 0;
-
-
- // state flags
- private boolean _cacheInProgress = true; // caching in progress, in the
first pass.
- private int _replayPosition = 0; // slider position in the stream
-
- public ResettableStream(LocalTaskQueue<IndexedMatrixValue> source) {
- this(source, _streamSeq.getNextID());
- }
- public ResettableStream(LocalTaskQueue<IndexedMatrixValue> source, long
streamId) {
- _source = source;
- _streamId = streamId;
- }
-
- /**
- * Dequeues a task. If it is the first, it reads from the disk and
stores in the cache.
- * For subsequent passes it reads from the memory.
- *
- * @return The next matrix value in the stream, or NO_MORE_TASKS
- */
- @Override
- public synchronized IndexedMatrixValue dequeueTask()
- throws InterruptedException {
- if (_cacheInProgress) {
- // First pass: Read value from the source and cache it,
and return.
- IndexedMatrixValue task = _source.dequeueTask();
- if (task != NO_MORE_TASKS) {
-
- OOCEvictionManager.put(_streamId, _numBlocks,
task);
- _numBlocks++;
-
- return task;
- } else {
- _cacheInProgress = false; // caching is complete
- _source.closeInput(); // close source stream
-
- // Notify all the waiting consumers waiting for
cache to fill with this stream
- notifyAll();
- return (IndexedMatrixValue) NO_MORE_TASKS;
- }
- } else {
- // Replay pass: read from the buffer
- if (_replayPosition < _numBlocks) {
- return OOCEvictionManager.get(_streamId,
_replayPosition++);
- } else {
- return (IndexedMatrixValue) NO_MORE_TASKS;
- }
- }
- }
-
- /**
- * Resets the stream to beginning to read the stream from start.
- * This can only be called once the stream is fully consumed once.
- */
- public synchronized void reset() throws InterruptedException {
- while (_cacheInProgress) {
- // Attempted to reset a stream that's not been fully
cached yet.
- wait();
- }
- _replayPosition = 0;
- }
-
- @Override
- public synchronized void closeInput() {
- _source.closeInput();
- }
-
- @Override
- public synchronized boolean isProcessed() {
- return false;
- }
-}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
new file mode 100644
index 0000000000..5f97bd99e9
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
+
+/**
+ * A {@link LocalTaskQueue} that notifies a single subscriber callback once per
+ * enqueued item, once when the input is closed, and on failure propagation.
+ * A subscriber is expected to {@link #dequeue()} one element per notification,
+ * as the tasks of {@code OOCInstruction.submitOOCTasks} do.
+ */
+public class SubscribableTaskQueue<T> extends LocalTaskQueue<T> implements OOCStream<T> {
+	// Read without holding the monitor in enqueue(), hence volatile
+	private volatile Runnable _subscriber;
+
+	@Override
+	public void enqueue(T t) {
+		try {
+			super.enqueueTask(t);
+		}
+		catch (InterruptedException e) {
+			Thread.currentThread().interrupt(); // preserve the interrupt status
+			throw new DMLRuntimeException(e);
+		}
+
+		// Read the field once, as closeInput() may reset it concurrently
+		final Runnable subscriber = _subscriber;
+		if(subscriber != null)
+			subscriber.run();
+	}
+
+	@Override
+	public T dequeue() {
+		try {
+			return super.dequeueTask();
+		}
+		catch (InterruptedException e) {
+			Thread.currentThread().interrupt(); // preserve the interrupt status
+			throw new DMLRuntimeException(e);
+		}
+	}
+
+	/**
+	 * Closes the input, sends the final (end-of-stream) notification and
+	 * releases the subscriber.
+	 */
+	@Override
+	public synchronized void closeInput() {
+		super.closeInput();
+
+		final Runnable subscriber = _subscriber;
+		if(subscriber != null) {
+			subscriber.run();
+			_subscriber = null;
+		}
+	}
+
+	@Override
+	public LocalTaskQueue<T> toLocalTaskQueue() {
+		return this;
+	}
+
+	@Override
+	public OOCStream<T> getReadStream() {
+		return this;
+	}
+
+	@Override
+	public OOCStream<T> getWriteStream() {
+		return this;
+	}
+
+	/**
+	 * Registers the subscriber and, outside the monitor, replays one
+	 * notification per already buffered item, plus one for the end marker if
+	 * the input is already closed.
+	 *
+	 * @throws DMLRuntimeException if a subscriber is already registered
+	 */
+	@Override
+	public void setSubscriber(Runnable subscriber) {
+		int queueSize;
+
+		synchronized (this) {
+			if(_subscriber != null)
+				throw new DMLRuntimeException("Cannot set multiple subscribers");
+
+			_subscriber = subscriber;
+			queueSize = _data.size();
+			queueSize += _closedInput ? 1 : 0; // To trigger the NO_MORE_TASK element
+		}
+
+		for (int i = 0; i < queueSize; i++)
+			subscriber.run();
+	}
+
+	/** Propagates the failure and wakes up the subscriber so it observes it. */
+	@Override
+	public synchronized void propagateFailure(DMLRuntimeException re) {
+		super.propagateFailure(re);
+
+		final Runnable subscriber = _subscriber;
+		if(subscriber != null)
+			subscriber.run();
+	}
+
+	@Override
+	public boolean hasStreamCache() {
+		return false;
+	}
+
+	@Override
+	public CachingStream getStreamCache() {
+		return null;
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
index b3f302c204..9040c369a2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TSMMOOCInstruction.java
@@ -66,7 +66,7 @@ public class TSMMOOCInstruction extends
ComputationOOCInstruction {
int nCols = (int) min.getDataCharacteristics().getCols();
int bLen = min.getDataCharacteristics().getBlocksize();
- LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
+ // consumed by explicitly dequeuing blocks in the aggregation loop below
+ OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
BinaryOperator plus =
InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
//validation check TODO extend compiler to not create OOC
otherwise
@@ -81,7 +81,7 @@ public class TSMMOOCInstruction extends
ComputationOOCInstruction {
try {
IndexedMatrixValue tmp = null;
// aggregate partial tsmm outputs into result as inputs
stream in
- while((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
+ // dequeue() yields LocalTaskQueue.NO_MORE_TASKS once the input stream is exhausted
+ while((tmp = qIn.dequeue()) !=
LocalTaskQueue.NO_MORE_TASKS) {
MatrixBlock partialResult = ((MatrixBlock)
tmp.getValue())
.transposeSelfMatrixMultOperations(new
MatrixBlock(), _type);
resultBlock.binaryOperationsInPlace(plus,
partialResult);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
index baf3ecea24..fd80b4e6e9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java
@@ -21,7 +21,6 @@ package org.apache.sysds.runtime.instructions.ooc;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
@@ -44,11 +43,11 @@ public class TeeOOCInstruction extends
ComputationOOCInstruction {
public void processInstruction( ExecutionContext ec ) {
//get input stream
MatrixObject min = ec.getMatrixObject(input1);
- LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
+ OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
- //get output and create new resettable stream
+		//get output and wrap the input into a caching stream, whose cached
+		//blocks can be replayed to multiple consumers
MatrixObject mo = ec.getMatrixObject(output);
- mo.setStreamHandle(new ResettableStream(qIn));
+ mo.setStreamHandle(new CachingStream(qIn));
mo.setMetaData(min.getMetaData());
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
index 05e31830a5..6558145ec2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
@@ -19,10 +19,8 @@
package org.apache.sysds.runtime.instructions.ooc;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -53,26 +51,17 @@ public class TransposeOOCInstruction extends
ComputationOOCInstruction {
// Create thread and process the transpose operation
MatrixObject min = ec.getMatrixObject(input1);
- LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
- LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
+ OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
+ OOCStream<IndexedMatrixValue> qOut = createWritableStream();
ec.getMatrixObject(output).setStreamHandle(qOut);
- submitOOCTask(() -> {
- IndexedMatrixValue tmp = null;
- try {
- while ((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
- MatrixBlock inBlock =
(MatrixBlock)tmp.getValue();
- long oldRowIdx =
tmp.getIndexes().getRowIndex();
- long oldColIdx =
tmp.getIndexes().getColumnIndex();
- MatrixBlock outBlock =
inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1);
- qOut.enqueueTask(new
IndexedMatrixValue(new MatrixIndexes(oldColIdx, oldRowIdx), outBlock));
- }
- qOut.closeInput();
- }
- catch(Exception ex) {
- throw new DMLRuntimeException(ex);
- }
- }, qIn, qOut);
+		// Each block moves to the mirrored block position (row and column
+		// index swapped); its content is reorganized by the reorg operator
+		mapOOC(qIn, qOut, blk -> {
+			final MatrixIndexes ix = blk.getIndexes();
+			final MatrixIndexes swapped = new MatrixIndexes(ix.getColumnIndex(), ix.getRowIndex());
+			final MatrixBlock transposed = ((MatrixBlock) blk.getValue())
+				.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1);
+			return new IndexedMatrixValue(swapped, transposed);
+		});
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
index 173486844a..08f00f86d2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
@@ -19,10 +19,8 @@
package org.apache.sysds.runtime.instructions.ooc;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
@@ -53,25 +51,15 @@ public class UnaryOOCInstruction extends
ComputationOOCInstruction {
UnaryOperator uop = (UnaryOperator) _uop;
// Create thread and process the unary operation
MatrixObject min = ec.getMatrixObject(input1);
- LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
- LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
+ OOCStream<IndexedMatrixValue> qIn = min.getStreamHandle();
+ OOCStream<IndexedMatrixValue> qOut = createWritableStream();
ec.getMatrixObject(output).setStreamHandle(qOut);
-
- submitOOCTask(() -> {
- IndexedMatrixValue tmp = null;
- try {
- while ((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
- IndexedMatrixValue tmpOut = new
IndexedMatrixValue();
- tmpOut.set(tmp.getIndexes(),
-
tmp.getValue().unaryOperations(uop, new MatrixBlock()));
- qOut.enqueueTask(tmpOut);
- }
- qOut.closeInput();
- }
- catch(Exception ex) {
- throw new DMLRuntimeException(ex);
- }
- }, qIn, qOut);
+		// Block positions are kept; only the block content is transformed
+		mapOOC(qIn, qOut, in -> {
+			final IndexedMatrixValue out = new IndexedMatrixValue();
+			out.set(in.getIndexes(), in.getValue().unaryOperations(uop, new MatrixBlock()));
+			return out;
+		});
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
index f928f0440b..489b277f74 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/CMOperator.java
@@ -55,6 +55,13 @@ public class CMOperator extends MultiThreadedOperator
_numThreads = numThreads;
}
+	/**
+	 * Copy constructor. If the source operator's value function is a {@link CM}
+	 * (which carries state), a fresh instance is obtained via
+	 * {@code CM.getCMFnObject} so the copy does not share it with {@code that};
+	 * any other value function is shared by reference.
+	 *
+	 * @param that operator to copy
+	 */
+	public CMOperator(CMOperator that) {
+		if (that.fn instanceof CM)
+			fn = CM.getCMFnObject((CM) that.fn);
+		else
+			fn = that.fn;
+		aggOpType = that.aggOpType;
+		_numThreads = that._numThreads;
+	}
+
public AggregateOperationTypes getAggOpType() {
return aggOpType;
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/OOCJoin.java
b/src/main/java/org/apache/sysds/runtime/util/OOCJoin.java
new file mode 100644
index 0000000000..81265b8a2d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/util/OOCJoin.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.util;
+
+import org.apache.logging.log4j.util.TriConsumer;
+import org.apache.sysds.runtime.DMLRuntimeException;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Pairwise hash join of two keyed inputs, a "left" and a "right" side.
+ * Each incoming item is matched by key against the buffered items of the
+ * opposite side. On a hit, the buffered item is removed and the pair is passed
+ * to the emitter as (key, leftItem, rightItem). Otherwise the item is buffered
+ * until its partner arrives.
+ * <p>
+ * Buffer access is guarded by {@code synchronized(this)}, so the two sides may
+ * be fed from different threads. The emitter is invoked outside the lock and
+ * may therefore run concurrently for different keys.
+ *
+ * @param <T> key type; used as a {@link HashMap} key, so it needs consistent
+ *            {@code equals}/{@code hashCode}
+ * @param <O> item type (items must be non-null)
+ */
+public class OOCJoin<T, O> {
+	// unmatched items added via the left side, keyed by join key
+	private final Map<T, O> left;
+	// unmatched items added via the right side, keyed by join key
+	private final Map<T, O> right;
+	// receives every matched (key, leftItem, rightItem) triple
+	private final TriConsumer<T, O, O> emitter;
+
+	/**
+	 * @param emitter callback receiving (key, leftItem, rightItem) for each match
+	 */
+	public OOCJoin(TriConsumer<T, O, O> emitter) {
+		this.left = new HashMap<>();
+		this.right = new HashMap<>();
+		this.emitter = emitter;
+	}
+
+	/** Adds an item to the left side; see {@link #add(boolean, Object, Object)}. */
+	public void addLeft(T idx, O item) {
+		add(true, idx, item);
+	}
+
+	/** Adds an item to the right side; see {@link #add(boolean, Object, Object)}. */
+	public void addRight(T idx, O item) {
+		add(false, idx, item);
+	}
+
+	/**
+	 * Verifies that every buffered item has found its partner.
+	 *
+	 * @throws DMLRuntimeException if either side still holds unmatched items
+	 */
+	public void close() {
+		synchronized (this) {
+			if (!left.isEmpty() || !right.isEmpty())
+				throw new DMLRuntimeException("There are still unprocessed items in the OOC join");
+		}
+	}
+
+	/**
+	 * Adds an item to one side of the join. If the opposite side already holds
+	 * an item with the same key, that item is removed and the pair is emitted.
+	 * Otherwise the new item is buffered.
+	 *
+	 * @param isLeft true to add to the left side, false for the right side
+	 * @param idx    join key
+	 * @param val    item to join; must not be null
+	 * @throws DMLRuntimeException if {@code val} is null, or if an unmatched item
+	 *                             with the same key is already buffered on this side
+	 */
+	public void add(boolean isLeft, T idx, O val) {
+		// a null item would be indistinguishable from a lookup miss below
+		if (val == null)
+			throw new DMLRuntimeException("Cannot add a null item to the OOC join (key: " + idx + ")");
+
+		Map<T, O> lookup = isLeft ? right : left;
+		Map<T, O> store = isLeft ? left : right;
+		O val2;
+
+		synchronized (this) {
+			val2 = lookup.remove(idx);
+			// overwriting a buffered duplicate would silently drop it, so reject it instead
+			if (val2 == null && store.putIfAbsent(idx, val) != null)
+				throw new DMLRuntimeException("Duplicate key " + idx + " on the "
+					+ (isLeft ? "left" : "right") + " side of the OOC join");
+		}
+
+		// emit outside the lock so a slow consumer does not block the other side
+		if (val2 != null)
+			emitter.accept(idx, isLeft ? val : val2, isLeft ? val2 : val);
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java
new file mode 100644
index 0000000000..dfa9413bfb
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixMatrixTest.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+/**
+ * Compares the out-of-core (OOC) execution of element-wise matrix-matrix
+ * multiplication (BinaryMatrixMatrix.dml) against a regular in-memory run,
+ * for all dense/sparse combinations of the two inputs.
+ */
+public class BinaryMatrixMatrixTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "BinaryMatrixMatrix";
+	private final static String TEST_DIR = "functions/ooc/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + BinaryMatrixMatrixTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-8;
+	private static final String INPUT_NAME_1 = "X";
+	private static final String INPUT_NAME_2 = "Y";
+	private static final String OUTPUT_NAME = "res";
+
+	private final static int rows = 1500;
+	private final static int cols = 1200;
+	private final static int blen = 1000;
+	private final static int maxVal = 7;
+	private final static double sparsity1 = 1;
+	private final static double sparsity2 = 0.05;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+		addTestConfiguration(TEST_NAME1, config);
+	}
+
+	@Test
+	public void testBinaryMatrixMatrixDenseDense() {
+		runBinaryMatrixMatrixTest(false, false);
+	}
+
+	@Test
+	public void testBinaryMatrixMatrixDenseSparse() {
+		runBinaryMatrixMatrixTest(false, true);
+	}
+
+	@Test
+	public void testBinaryMatrixMatrixSparseDense() {
+		runBinaryMatrixMatrixTest(true, false);
+	}
+
+	@Test
+	public void testBinaryMatrixMatrixSparseSparse() {
+		runBinaryMatrixMatrixTest(true, true);
+	}
+
+	/**
+	 * Runs the script once with -ooc and once without, then compares both results.
+	 *
+	 * @param sparse1 whether input X is generated sparse
+	 * @param sparse2 whether input Y is generated sparse
+	 */
+	private void runBinaryMatrixMatrixTest(boolean sparse1, boolean sparse2) {
+		Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME1);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+			programArgs = new String[] {"-explain", "-stats", "-ooc", "-args",
+				input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME)};
+
+			// Generate rows x cols inputs, i.e., the same dimensions that are
+			// declared to the writer and in the metadata files below
+			double[][] X_data = getRandomMatrix(rows, cols, 1, maxVal, sparse1 ? sparsity2 : sparsity1, 7);
+			double[][] Y_data = getRandomMatrix(rows, cols, 0, 1, sparse2 ? sparsity2 : sparsity1, 8);
+
+			// Write X and Y as binary-block matrices with metadata files
+			writeBinaryInput(INPUT_NAME_1, DataConverter.convertToMatrixBlock(X_data));
+			writeBinaryInput(INPUT_NAME_2, DataConverter.convertToMatrixBlock(Y_data));
+
+			runTest(true, false, null, -1);
+
+			// check that the multiplication was executed by an OOC instruction
+			Assert.assertTrue("OOC wasn't used for multiplication",
+				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.MULT));
+
+			// rerun without the ooc flag to obtain the reference result
+			programArgs = new String[] {"-explain", "-stats", "-args",
+				input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME + "_target")};
+			runTest(true, false, null, -1);
+
+			// compare the OOC result against the in-memory reference
+			MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+				Types.FileFormat.BINARY, rows, cols, blen);
+			MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+				Types.FileFormat.BINARY, rows, cols, blen);
+			TestUtils.compareMatrices(ret1, ret2, eps);
+		}
+		catch(IOException e) {
+			throw new RuntimeException(e);
+		}
+		finally {
+			resetExecMode(platformOld);
+		}
+	}
+
+	/**
+	 * Writes a matrix as a rows x cols binary-block input file plus its metadata file.
+	 *
+	 * @param name input name, resolved via {@code input(...)}
+	 * @param mb   matrix to write; expected to be rows x cols
+	 * @throws IOException if writing the data or the metadata fails
+	 */
+	private void writeBinaryInput(String name, MatrixBlock mb) throws IOException {
+		MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+		writer.writeMatrixToHDFS(mb, input(name), rows, cols, blen, mb.getNonZeros());
+		HDFSTool.writeMetaDataFile(input(name + ".mtd"), Types.ValueType.FP64,
+			new MatrixCharacteristics(rows, cols, blen, mb.getNonZeros()), Types.FileFormat.BINARY);
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java
new file mode 100644
index 0000000000..e84d36e41b
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryMatrixScalarTest.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+
+/**
+ * Compares the out-of-core (OOC) execution of scalar-matrix arithmetic
+ * (BinaryMatrixScalar.dml: a division and an addition) against a regular
+ * in-memory run, for a dense and a sparse input.
+ */
+public class BinaryMatrixScalarTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "BinaryMatrixScalar";
+	private final static String TEST_DIR = "functions/ooc/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + BinaryMatrixScalarTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-8;
+	private static final String INPUT_NAME_1 = "X";
+	private static final String OUTPUT_NAME = "res";
+
+	private final static int rows = 1500;
+	private final static int cols = 1200;
+	private final static int blen = 1000;
+	private final static int maxVal = 7;
+	private final static double sparsity1 = 1;
+	private final static double sparsity2 = 0.05;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+		addTestConfiguration(TEST_NAME1, config);
+	}
+
+	@Test
+	public void testBinaryMatrixScalarDense() {
+		runBinaryMatrixScalarTest(false);
+	}
+
+	@Test
+	public void testBinaryMatrixScalarSparse() {
+		runBinaryMatrixScalarTest(true);
+	}
+
+	/**
+	 * Runs the script once with -ooc and once without, then compares both results.
+	 *
+	 * @param sparse whether input X is generated sparse
+	 */
+	private void runBinaryMatrixScalarTest(boolean sparse) {
+		Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME1);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+			programArgs = new String[] {"-explain", "-stats", "-ooc", "-args",
+				input(INPUT_NAME_1), output(OUTPUT_NAME)};
+
+			// Generate a rows x cols input, i.e., the same dimensions that are
+			// declared to the writer and in the metadata file below
+			double[][] X_data = getRandomMatrix(rows, cols, 1, maxVal, sparse ? sparsity2 : sparsity1, 7);
+
+			// Write X as a binary-block matrix with a metadata file
+			writeBinaryInput(INPUT_NAME_1, DataConverter.convertToMatrixBlock(X_data));
+
+			runTest(true, false, null, -1);
+
+			// check that both scalar operations were executed by OOC instructions
+			Assert.assertTrue("OOC wasn't used for division",
+				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.DIV));
+			Assert.assertTrue("OOC wasn't used for addition",
+				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.PLUS));
+
+			// rerun without the ooc flag to obtain the reference result
+			programArgs = new String[] {"-explain", "-stats", "-args",
+				input(INPUT_NAME_1), output(OUTPUT_NAME + "_target")};
+			runTest(true, false, null, -1);
+
+			// compare the OOC result against the in-memory reference
+			MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+				Types.FileFormat.BINARY, rows, cols, blen);
+			MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+				Types.FileFormat.BINARY, rows, cols, blen);
+			TestUtils.compareMatrices(ret1, ret2, eps);
+		}
+		catch(IOException e) {
+			throw new RuntimeException(e);
+		}
+		finally {
+			resetExecMode(platformOld);
+		}
+	}
+
+	/**
+	 * Writes a matrix as a rows x cols binary-block input file plus its metadata file.
+	 *
+	 * @param name input name, resolved via {@code input(...)}
+	 * @param mb   matrix to write; expected to be rows x cols
+	 * @throws IOException if writing the data or the metadata fails
+	 */
+	private void writeBinaryInput(String name, MatrixBlock mb) throws IOException {
+		MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+		writer.writeMatrixToHDFS(mb, input(name), rows, cols, blen, mb.getNonZeros());
+		HDFSTool.writeMetaDataFile(input(name + ".mtd"), Types.ValueType.FP64,
+			new MatrixCharacteristics(rows, cols, blen, mb.getNonZeros()), Types.FileFormat.BINARY);
+	}
+}
diff --git a/src/test/scripts/functions/ooc/BinaryMatrixMatrix.dml
b/src/test/scripts/functions/ooc/BinaryMatrixMatrix.dml
new file mode 100644
index 0000000000..ad7ed6bb55
--- /dev/null
+++ b/src/test/scripts/functions/ooc/BinaryMatrixMatrix.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the two input matrices (as streams when executed with -ooc)
+X = read($1);
+Y = read($2);
+
+# Element-wise multiplications; BinaryMatrixMatrixTest checks that the
+# multiply runs as an OOC instruction
+res = X * Y
+res = res * X
+
+# Write the result in binary format for comparison with the in-memory run
+write(res, $3, format="binary");
diff --git a/src/test/scripts/functions/ooc/BinaryMatrixScalar.dml
b/src/test/scripts/functions/ooc/BinaryMatrixScalar.dml
new file mode 100644
index 0000000000..e5b19fe5a7
--- /dev/null
+++ b/src/test/scripts/functions/ooc/BinaryMatrixScalar.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the input matrix (as a stream when executed with -ooc)
+X = read($1);
+
+# Scalar-matrix division followed by matrix-scalar addition; BinaryMatrixScalarTest
+# checks that both run as OOC instructions. Cells where X == 0 (sparse test case)
+# divide by zero.
+OOC = 5 / X
+res = OOC + 3
+
+# Write the result in binary format for comparison with the in-memory run
+write(res, $2, format="binary");