This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new d38e56ccc6 [SYSTEMDS-3923] Improve exception handling OOC instructions
d38e56ccc6 is described below
commit d38e56ccc65c8935e3ae31787455cc7a4340b233
Author: Jannik Lindemann <[email protected]>
AuthorDate: Tue Oct 28 12:22:23 2025 +0100
[SYSTEMDS-3923] Improve exception handling OOC instructions
Closes #2346.
---
.../controlprogram/parfor/LocalTaskQueue.java | 21 +++-
.../ooc/AggregateUnaryOOCInstruction.java | 12 +--
.../instructions/ooc/BinaryOOCInstruction.java | 10 +-
.../ooc/MatrixVectorBinaryOOCInstruction.java | 13 +--
.../runtime/instructions/ooc/OOCInstruction.java | 35 +++++++
.../instructions/ooc/ReblockOOCInstruction.java | 8 +-
.../instructions/ooc/TransposeOOCInstruction.java | 12 +--
.../instructions/ooc/UnaryOOCInstruction.java | 11 +--
.../functions/ooc/OOCExceptionHandlingTest.java | 106 +++++++++++++++++++++
.../scripts/functions/ooc/OOCExceptionHandling.dml | 28 ++++++
10 files changed, 199 insertions(+), 57 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
index e1099f715b..350fc8de3b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
@@ -23,6 +23,7 @@ import java.util.LinkedList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
/**
* This class provides a way of dynamic task distribution to multiple workers
@@ -43,7 +44,8 @@ public class LocalTaskQueue<T>
public static final Object NO_MORE_TASKS = null; //object to signal
NO_MORE_TASKS
private LinkedList<T> _data = null;
- private boolean _closedInput = false;
+ private boolean _closedInput = false;
+ private DMLRuntimeException _failure = null;
private static final Log LOG =
LogFactory.getLog(LocalTaskQueue.class.getName());
public LocalTaskQueue()
@@ -61,11 +63,14 @@ public class LocalTaskQueue<T>
public synchronized void enqueueTask( T t )
throws InterruptedException
{
- while( _data.size() + 1 > MAX_SIZE )
+ while( _data.size() + 1 > MAX_SIZE && _failure == null )
{
LOG.warn("MAX_SIZE of task queue reached.");
wait(); //max constraint reached, wait for read
}
+
+ if ( _failure != null )
+ throw _failure;
_data.addLast( t );
@@ -82,13 +87,16 @@ public class LocalTaskQueue<T>
public synchronized T dequeueTask()
throws InterruptedException
{
- while( _data.isEmpty() )
+ while( _data.isEmpty() && _failure == null )
{
if( !_closedInput )
wait(); // wait for writers
else
return (T)NO_MORE_TASKS;
}
+
+ if ( _failure != null )
+ throw _failure;
T t = _data.removeFirst();
@@ -111,6 +119,13 @@ public class LocalTaskQueue<T>
return _closedInput && _data.isEmpty();
}
+ public synchronized void propagateFailure(DMLRuntimeException failure) {
+ if (_failure == null) {
+ _failure = failure;
+ notifyAll();
+ }
+ }
+
@Override
public synchronized String toString()
{
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
index c01fb3fa37..8c8a64b022 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
@@ -90,9 +90,8 @@ public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction {
LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
ec.getMatrixObject(output).setStreamHandle(qOut);
- ExecutorService pool = CommonThreadPool.get();
- try {
- pool.submit(() -> {
+
+ submitOOCTask(() -> {
IndexedMatrixValue tmp = null;
try {
while((tmp = q.dequeueTask())
!= LocalTaskQueue.NO_MORE_TASKS) {
@@ -152,12 +151,7 @@ public class AggregateUnaryOOCInstruction extends ComputationOOCInstruction {
catch(Exception ex) {
throw new
DMLRuntimeException(ex);
}
- });
- } catch (Exception ex) {
- throw new DMLRuntimeException(ex);
- } finally {
- pool.shutdown();
- }
+ }, q, qOut);
}
// full aggregation
else {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
index fe76e60b9e..82ad12ae55 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
@@ -70,9 +70,7 @@ public class BinaryOOCInstruction extends ComputationOOCInstruction {
LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
ec.getMatrixObject(output).setStreamHandle(qOut);
- ExecutorService pool = CommonThreadPool.get();
- try {
- pool.submit(() -> {
+ submitOOCTask(() -> {
IndexedMatrixValue tmp = null;
try {
while((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
@@ -86,10 +84,6 @@ public class BinaryOOCInstruction extends ComputationOOCInstruction {
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
- });
- }
- finally {
- pool.shutdown();
- }
+ }, qIn, qOut);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
index ae84e4b541..c1d1ed6ace 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
@@ -90,10 +90,7 @@ public class MatrixVectorBinaryOOCInstruction extends ComputationOOCInstruction
BinaryOperator plus =
InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
ec.getMatrixObject(output).setStreamHandle(qOut);
- ExecutorService pool = CommonThreadPool.get();
- try {
- // Core logic: background thread
- pool.submit(() -> {
+ submitOOCTask(() -> {
IndexedMatrixValue tmp = null;
try {
while((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
@@ -134,12 +131,6 @@ public class MatrixVectorBinaryOOCInstruction extends ComputationOOCInstruction
finally {
qOut.closeInput();
}
- });
- } catch (Exception e) {
- throw new DMLRuntimeException(e);
- }
- finally {
- pool.shutdown();
- }
+ }, qIn, qOut);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index d55d1ee594..0d15949289 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -22,12 +22,16 @@ package org.apache.sysds.runtime.instructions.ooc;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
import java.util.HashMap;
+import java.util.concurrent.ExecutorService;
public abstract class OOCInstruction extends Instruction {
protected static final Log LOG =
LogFactory.getLog(OOCInstruction.class.getName());
@@ -86,6 +90,37 @@ public abstract class OOCInstruction extends Instruction {
ec.maintainLineageDebuggerInfo(this);
}
+ protected void submitOOCTask(Runnable r, LocalTaskQueue<?>... queues) {
+ ExecutorService pool = CommonThreadPool.get();
+ try {
+ pool.submit(oocTask(r, queues));
+ }
+ catch (Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ finally {
+ pool.shutdown();
+ }
+ }
+
+ private Runnable oocTask(Runnable r, LocalTaskQueue<?>... queues) {
+ return () -> {
+ try {
+ r.run();
+ }
+ catch (Exception ex) {
+ DMLRuntimeException re = ex instanceof
DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
+
+ for (LocalTaskQueue<?> q : queues) {
+ q.propagateFailure(re);
+ }
+
+ // Rethrow to ensure proper future handling
+ throw re;
+ }
+ };
+ }
+
/**
* Tracks blocks and their counts to enable early emission
* once all blocks for a given index are processed.
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
index 9a7059be51..06386c5d66 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
@@ -79,13 +79,7 @@ public class ReblockOOCInstruction extends ComputationOOCInstruction {
//create queue, spawn thread for asynchronous reading, and
return
LocalTaskQueue<IndexedMatrixValue> q = new
LocalTaskQueue<IndexedMatrixValue>();
- ExecutorService pool = CommonThreadPool.get();
- try {
- pool.submit(() -> readBinaryBlock(q,
min.getFileName()));
- }
- finally {
- pool.shutdown();
- }
+ submitOOCTask(() -> readBinaryBlock(q, min.getFileName()), q);
MatrixObject mout = ec.getMatrixObject(output);
mout.setStreamHandle(q);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
index 212d0d5c56..fce5408960 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
@@ -60,10 +60,7 @@ public class TransposeOOCInstruction extends ComputationOOCInstruction {
LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
ec.getMatrixObject(output).setStreamHandle(qOut);
-
- ExecutorService pool = CommonThreadPool.get();
- try {
- pool.submit(() -> {
+ submitOOCTask(() -> {
IndexedMatrixValue tmp = null;
try {
while ((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
@@ -79,11 +76,6 @@ public class TransposeOOCInstruction extends ComputationOOCInstruction {
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
- });
- } catch (Exception ex) {
- throw new DMLRuntimeException(ex);
- } finally {
- pool.shutdown();
- }
+ }, qIn, qOut);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
index 13cd5463ed..63f42f5bf1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
@@ -61,9 +61,7 @@ public class UnaryOOCInstruction extends ComputationOOCInstruction {
ec.getMatrixObject(output).setStreamHandle(qOut);
- ExecutorService pool = CommonThreadPool.get();
- try {
- pool.submit(() -> {
+ submitOOCTask(() -> {
IndexedMatrixValue tmp = null;
try {
while ((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
@@ -77,11 +75,6 @@ public class UnaryOOCInstruction extends ComputationOOCInstruction {
catch(Exception ex) {
throw new DMLRuntimeException(ex);
}
- });
- } catch (Exception ex) {
- throw new DMLRuntimeException(ex);
- } finally {
- pool.shutdown();
- }
+ }, qIn, qOut);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java
new file mode 100644
index 0000000000..3bd32d7eff
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public class OOCExceptionHandlingTest extends AutomatedTestBase {
+ private final static String TEST_NAME1 = "OOCExceptionHandling";
+ private final static String TEST_DIR = "functions/ooc/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
OOCExceptionHandlingTest.class.getSimpleName() + "/";
+ private static final String INPUT_NAME = "X";
+ private static final String INPUT_NAME_2 = "Y";
+ private static final String OUTPUT_NAME = "res";
+
+ private final static int rows = 1000;
+ private final static int cols = 1000;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ TestConfiguration config = new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+ addTestConfiguration(TEST_NAME1, config);
+ }
+
+ @Test
+ public void runOOCExceptionHandlingTest1() {
+ runOOCExceptionHandlingTest(500);
+ }
+
+ @Test
+ public void runOOCExceptionHandlingTest2() {
+ runOOCExceptionHandlingTest(750);
+ }
+
+
+ private void runOOCExceptionHandlingTest(int misalignVals) {
+ Types.ExecMode platformOld =
setExecMode(Types.ExecMode.SINGLE_NODE);
+
+ try {
+ getAndLoadTestConfiguration(TEST_NAME1);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+ programArgs = new String[] {"-explain", "-stats",
"-ooc", "-args", input(INPUT_NAME), input(INPUT_NAME_2), output(OUTPUT_NAME)};
+
+ // 1. Generate the data in-memory as MatrixBlock objects
+ double[][] A_data = getRandomMatrix(rows, cols, 1, 2,
1, 7);
+ double[][] B_data = getRandomMatrix(rows, 1, 1, 2, 1,
7);
+
+ // 2. Convert the double arrays to MatrixBlock objects
+ MatrixBlock A_mb =
DataConverter.convertToMatrixBlock(A_data);
+ MatrixBlock B_mb =
DataConverter.convertToMatrixBlock(B_data);
+
+ // 3. Create a binary matrix writer
+ MatrixWriter writer =
MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+
+ // 4. Write matrix A to a binary SequenceFile
+
+ // Here, we write two faulty matrices which will only
be recognized at runtime
+ writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows,
cols, misalignVals, A_mb.getNonZeros());
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"),
Types.ValueType.FP64,
+ new MatrixCharacteristics(rows, cols, 1000,
A_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+ writer.writeMatrixToHDFS(B_mb, input(INPUT_NAME_2),
rows, 1, 1000, B_mb.getNonZeros());
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 +
".mtd"), Types.ValueType.FP64,
+ new MatrixCharacteristics(rows, 1, 1000,
B_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+ runTest(true, true, null, -1);
+ }
+ catch(IOException e) {
+ throw new RuntimeException(e);
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+}
diff --git a/src/test/scripts/functions/ooc/OOCExceptionHandling.dml b/src/test/scripts/functions/ooc/OOCExceptionHandling.dml
new file mode 100644
index 0000000000..6b7dc6038e
--- /dev/null
+++ b/src/test/scripts/functions/ooc/OOCExceptionHandling.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the input matrix as a stream
+X = read($1);
+b = read($2);
+
+OOC = colSums(X %*% b);
+
+write(OOC, $3, format="binary");