This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 78c8d7fccb [SYSTEMDS-3924] New primitive for OOC stream creation /
reset
78c8d7fccb is described below
commit 78c8d7fccbeb12665939cac690f01215fb354249
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Oct 18 15:06:20 2025 +0200
[SYSTEMDS-3924] New primitive for OOC stream creation / reset
This patch improves the OOC backend with a new primitive that
automatically creates new or resets existing OOC streams on
getStreamHandle, so that OOC instructions don't need to handle this
issue individually but can still probe the existence of active streams
via hasStreamHandle. As a result, if there are OOC instructions that
consume materialized intermediates, we automatically create new
streams from these intermediates.
---
.../sysds/hops/rewrite/RewriteInjectOOCTee.java | 4 +-
.../controlprogram/caching/CacheableData.java | 44 +++++++++++++++++++++-
.../controlprogram/parfor/LocalTaskQueue.java | 4 ++
.../runtime/instructions/ooc/ResettableStream.java | 7 +++-
.../apache/sysds/runtime/util/UtilFunctions.java | 24 ++++++++++++
.../apache/sysds/test/functions/ooc/lmDSTest.java | 7 +++-
src/test/scripts/functions/ooc/lmDS.dml | 1 +
7 files changed, 86 insertions(+), 5 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
index 2e849f74fb..54dffa263e 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
@@ -51,6 +51,8 @@ import java.util.Set;
*/
public class RewriteInjectOOCTee extends HopRewriteRule {
+ public static boolean APPLY_ONLY_XtX_PATTERN = false;
+
private static final Set<Long> rewrittenHops = new HashSet<>();
private static final Map<Long, Hop> handledHop = new HashMap<>();
@@ -140,7 +142,7 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
&& hop.getDataType().isMatrix()
&& !HopRewriteUtils.isData(hop, OpOpData.TEE)
&& hop.getParent().size() > 1
- && isSelfTranposePattern(hop)) //FIXME remove
+ && (!APPLY_ONLY_XtX_PATTERN ||
isSelfTranposePattern(hop))) //FIXME remove
{
rewriteCandidates.add(hop);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 7517bf0e35..3eb9a320da 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -25,6 +25,7 @@ import java.lang.ref.SoftReference;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.stream.LongStream;
import org.apache.commons.lang3.mutable.MutableBoolean;
import org.apache.commons.logging.Log;
@@ -48,6 +49,7 @@ import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
+import org.apache.sysds.runtime.instructions.ooc.ResettableStream;
import org.apache.sysds.runtime.instructions.spark.data.BroadcastObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
@@ -55,12 +57,14 @@ import org.apache.sysds.runtime.io.FileFormatProperties;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.io.ReaderWriterFederated;
import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaData;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.LocalFileUtils;
+import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
@@ -466,8 +470,44 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
}
public LocalTaskQueue<IndexedMatrixValue> getStreamHandle() {
+ if( !hasStreamHandle() ) {
+ _streamHandle = new LocalTaskQueue<>();
+ DataCharacteristics dc = getDataCharacteristics();
+ MatrixBlock src = (MatrixBlock)acquireReadAndRelease();
+ LongStream.range(0, dc.getNumBlocks())
+ .mapToObj(i ->
UtilFunctions.createIndexedMatrixBlock(src, dc, i))
+ .forEach( blk -> {
+ try{
+ _streamHandle.enqueueTask(blk);
+ }
+ catch(Exception ex) {
+ throw new
DMLRuntimeException(ex);
+ }});
+ _streamHandle.closeInput();
+ }
+ else if(_streamHandle != null && _streamHandle.isProcessed()
+ && _streamHandle instanceof ResettableStream)
+ {
+ try {
+ ((ResettableStream)_streamHandle).reset();
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
return _streamHandle;
}
+
+ /**
+ * Probes if stream handle is existing, because
<code>getStreamHandle</code>
+ * creates a new stream if not existing.
+ *
+ * @return true if existing, false otherwise
+ */
+ public boolean hasStreamHandle() {
+ return _streamHandle != null && !_streamHandle.isProcessed();
+ }
@SuppressWarnings({ "rawtypes", "unchecked" })
public void setBroadcastHandle( BroadcastObject bc ) {
@@ -592,7 +632,7 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
//mark for initial local write despite
read operation
_requiresLocalWrite = false;
}
- else if( getStreamHandle() != null ) {
+ else if( hasStreamHandle() ) {
_data = readBlobFromStream(
getStreamHandle() );
}
else if( getRDDHandle()==null ||
getRDDHandle().allowsShortCircuitRead() ) {
@@ -909,7 +949,7 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
// a) get the matrix
boolean federatedWrite = (outputFormat != null ) &&
outputFormat.contains("federated");
- if(getStreamHandle()!=null) {
+ if(hasStreamHandle()) {
try {
long totalNnz =
writeStreamToHDFS(fName, outputFormat, replication, formatProperties);
updateDataCharacteristics(new
MatrixCharacteristics(
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
index cef7015e06..e1099f715b 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
@@ -106,6 +106,10 @@ public class LocalTaskQueue<T>
_closedInput = true;
notifyAll(); //notify all waiting readers
}
+
+ public synchronized boolean isProcessed() {
+ return _closedInput && _data.isEmpty();
+ }
@Override
public synchronized String toString()
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java
index e6f30cae60..038e1a8b98 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ResettableStream.java
@@ -83,7 +83,7 @@ public class ResettableStream extends
LocalTaskQueue<IndexedMatrixValue> {
* This can only be called once the stream is fully consumed once.
*/
public synchronized void reset() throws InterruptedException {
- if (_cacheInProgress) {
+ while (_cacheInProgress) {
// Attempted to reset a stream that's not been fully
cached yet.
wait();
}
@@ -94,4 +94,9 @@ public class ResettableStream extends
LocalTaskQueue<IndexedMatrixValue> {
public synchronized void closeInput() {
_source.closeInput();
}
+
+ @Override
+ public synchronized boolean isProcessed() {
+ return false;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index cc370d6ae9..04a0bd1ab8 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -53,8 +53,10 @@ import org.apache.sysds.runtime.frame.data.columns.CharArray;
import org.apache.sysds.runtime.frame.data.columns.HashIntegerArray;
import org.apache.sysds.runtime.frame.data.columns.HashLongArray;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.Pair;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.TensorCharacteristics;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
@@ -1471,4 +1473,26 @@ public class UtilFunctions {
return joined.split("\\s+");
}
+
+ public static IndexedMatrixValue createIndexedMatrixBlock(MatrixBlock
mb, DataCharacteristics mc, long ix) {
+ try {
+ //compute block indexes
+ long blockRow = ix / mc.getNumColBlocks();
+ long blockCol = ix % mc.getNumColBlocks();
+ //compute block sizes
+ int maxRow =
UtilFunctions.computeBlockSize(mc.getRows(), blockRow+1, mc.getBlocksize());
+ int maxCol =
UtilFunctions.computeBlockSize(mc.getCols(), blockCol+1, mc.getBlocksize());
+ //copy sub-matrix to block
+ MatrixBlock block = new MatrixBlock(maxRow, maxCol,
mb.isInSparseFormat());
+ int row_offset = (int)blockRow*mc.getBlocksize();
+ int col_offset = (int)blockCol*mc.getBlocksize();
+ block = mb.slice( row_offset, row_offset+maxRow-1,
+ col_offset, col_offset+maxCol-1, false, block );
+ //create key-value pair
+ return new IndexedMatrixValue(new
MatrixIndexes(blockRow+1, blockCol+1), block);
+ }
+ catch(DMLRuntimeException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java
index e6a147775f..0381ffb399 100644
--- a/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java
@@ -20,6 +20,7 @@
package org.apache.sysds.test.functions.ooc;
import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.rewrite.RewriteInjectOOCTee;
import org.apache.sysds.runtime.io.MatrixWriter;
import org.apache.sysds.runtime.io.MatrixWriterFactory;
import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
@@ -70,9 +71,12 @@ public class lmDSTest extends AutomatedTestBase {
private void runMatrixVectorMultiplicationTest(int cols)
{
Types.ExecMode platformOld =
setExecMode(Types.ExecMode.SINGLE_NODE);
-
+ boolean oldFlag = RewriteInjectOOCTee.APPLY_ONLY_XtX_PATTERN;
+
try
{
+ RewriteInjectOOCTee.APPLY_ONLY_XtX_PATTERN = true;
+
getAndLoadTestConfiguration(TEST_NAME1);
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
@@ -117,6 +121,7 @@ public class lmDSTest extends AutomatedTestBase {
}
finally {
resetExecMode(platformOld);
+ RewriteInjectOOCTee.APPLY_ONLY_XtX_PATTERN = oldFlag;
}
}
}
diff --git a/src/test/scripts/functions/ooc/lmDS.dml
b/src/test/scripts/functions/ooc/lmDS.dml
index 930d956c50..32b84404b0 100644
--- a/src/test/scripts/functions/ooc/lmDS.dml
+++ b/src/test/scripts/functions/ooc/lmDS.dml
@@ -25,4 +25,5 @@ y = read($2)
XtX = t(X) %*% X; # 500 x 500
Xty = t(X) %*% y; # 500 x 1
R = solve(XtX, Xty)
+print(sum(R!=0))
write(R, $3, format="binary")