This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new d38e56ccc6 [SYSTEMDS-3923] Improve exception handling OOC instructions
d38e56ccc6 is described below

commit d38e56ccc65c8935e3ae31787455cc7a4340b233
Author: Jannik Lindemann <[email protected]>
AuthorDate: Tue Oct 28 12:22:23 2025 +0100

    [SYSTEMDS-3923] Improve exception handling OOC instructions
    
    Closes #2346.
---
 .../controlprogram/parfor/LocalTaskQueue.java      |  21 +++-
 .../ooc/AggregateUnaryOOCInstruction.java          |  12 +--
 .../instructions/ooc/BinaryOOCInstruction.java     |  10 +-
 .../ooc/MatrixVectorBinaryOOCInstruction.java      |  13 +--
 .../runtime/instructions/ooc/OOCInstruction.java   |  35 +++++++
 .../instructions/ooc/ReblockOOCInstruction.java    |   8 +-
 .../instructions/ooc/TransposeOOCInstruction.java  |  12 +--
 .../instructions/ooc/UnaryOOCInstruction.java      |  11 +--
 .../functions/ooc/OOCExceptionHandlingTest.java    | 106 +++++++++++++++++++++
 .../scripts/functions/ooc/OOCExceptionHandling.dml |  28 ++++++
 10 files changed, 199 insertions(+), 57 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
index e1099f715b..350fc8de3b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java
@@ -23,6 +23,7 @@ import java.util.LinkedList;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
 
 /**
  * This class provides a way of dynamic task distribution to multiple workers
@@ -43,7 +44,8 @@ public class LocalTaskQueue<T>
        public static final Object NO_MORE_TASKS = null; //object to signal 
NO_MORE_TASKS
        
        private LinkedList<T>  _data        = null;
-       private boolean            _closedInput = false; 
+       private boolean            _closedInput = false;
+       private DMLRuntimeException _failure = null;
        private static final Log LOG = 
LogFactory.getLog(LocalTaskQueue.class.getName());
        
        public LocalTaskQueue()
@@ -61,11 +63,14 @@ public class LocalTaskQueue<T>
        public synchronized void enqueueTask( T t ) 
                throws InterruptedException
        {
-               while( _data.size() + 1 > MAX_SIZE )
+               while( _data.size() + 1 > MAX_SIZE && _failure == null )
                {
                        LOG.warn("MAX_SIZE of task queue reached.");
                        wait(); //max constraint reached, wait for read
                }
+
+               if ( _failure != null )
+                       throw _failure;
                
                _data.addLast( t );
                
@@ -82,13 +87,16 @@ public class LocalTaskQueue<T>
        public synchronized T dequeueTask() 
                throws InterruptedException
        {
-               while( _data.isEmpty() )
+               while( _data.isEmpty() && _failure == null )
                {
                        if( !_closedInput )
                                wait(); // wait for writers
                        else
                                return (T)NO_MORE_TASKS; 
                }
+
+               if ( _failure != null )
+                       throw _failure;
                
                T t = _data.removeFirst();
                
@@ -111,6 +119,13 @@ public class LocalTaskQueue<T>
                return _closedInput && _data.isEmpty();
        }
 
+       public synchronized void propagateFailure(DMLRuntimeException failure) {
+               if (_failure == null) {
+                       _failure = failure;
+                       notifyAll();
+               }
+       }
+
        @Override
        public synchronized String toString() 
        {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
index c01fb3fa37..8c8a64b022 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
@@ -90,9 +90,8 @@ public class AggregateUnaryOOCInstruction extends 
ComputationOOCInstruction {
 
                        LocalTaskQueue<IndexedMatrixValue> qOut = new 
LocalTaskQueue<>();
                        ec.getMatrixObject(output).setStreamHandle(qOut);
-                       ExecutorService pool = CommonThreadPool.get();
-                       try {
-                               pool.submit(() -> {
+
+                       submitOOCTask(() -> {
                                        IndexedMatrixValue tmp = null;
                                        try {
                                                while((tmp = q.dequeueTask()) 
!= LocalTaskQueue.NO_MORE_TASKS) {
@@ -152,12 +151,7 @@ public class AggregateUnaryOOCInstruction extends 
ComputationOOCInstruction {
                                        catch(Exception ex) {
                                                throw new 
DMLRuntimeException(ex);
                                        }
-                               });
-                       } catch (Exception ex) {
-                               throw new DMLRuntimeException(ex);
-                       } finally {
-                               pool.shutdown();
-                       }
+                       }, q, qOut);
                }
                // full aggregation
                else {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
index fe76e60b9e..82ad12ae55 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java
@@ -70,9 +70,7 @@ public class BinaryOOCInstruction extends 
ComputationOOCInstruction {
                LocalTaskQueue<IndexedMatrixValue> qOut = new 
LocalTaskQueue<>();
                ec.getMatrixObject(output).setStreamHandle(qOut);
                
-               ExecutorService pool = CommonThreadPool.get();
-               try {
-                       pool.submit(() -> {
+               submitOOCTask(() -> {
                                IndexedMatrixValue tmp = null;
                                try {
                                        while((tmp = qIn.dequeueTask()) != 
LocalTaskQueue.NO_MORE_TASKS) {
@@ -86,10 +84,6 @@ public class BinaryOOCInstruction extends 
ComputationOOCInstruction {
                                catch(Exception ex) {
                                        throw new DMLRuntimeException(ex);
                                }
-                       });
-               }
-               finally {
-                       pool.shutdown();
-               }
+               }, qIn, qOut);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
index ae84e4b541..c1d1ed6ace 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
@@ -90,10 +90,7 @@ public class MatrixVectorBinaryOOCInstruction extends 
ComputationOOCInstruction
                BinaryOperator plus = 
InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
                ec.getMatrixObject(output).setStreamHandle(qOut);
 
-               ExecutorService pool = CommonThreadPool.get();
-               try {
-                       // Core logic: background thread
-                       pool.submit(() -> {
+               submitOOCTask(() -> {
                                IndexedMatrixValue tmp = null;
                                try {
                                        while((tmp = qIn.dequeueTask()) != 
LocalTaskQueue.NO_MORE_TASKS) {
@@ -134,12 +131,6 @@ public class MatrixVectorBinaryOOCInstruction extends 
ComputationOOCInstruction
                                finally {
                                        qOut.closeInput();
                                }
-                       });
-               } catch (Exception e) {
-                       throw new DMLRuntimeException(e);
-               }
-               finally {
-                       pool.shutdown();
-               }
+               }, qIn, qOut);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index d55d1ee594..0d15949289 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -22,12 +22,16 @@ package org.apache.sysds.runtime.instructions.ooc;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 
 import java.util.HashMap;
+import java.util.concurrent.ExecutorService;
 
 public abstract class OOCInstruction extends Instruction {
        protected static final Log LOG = 
LogFactory.getLog(OOCInstruction.class.getName());
@@ -86,6 +90,37 @@ public abstract class OOCInstruction extends Instruction {
                        ec.maintainLineageDebuggerInfo(this);
        }
 
+       protected void submitOOCTask(Runnable r, LocalTaskQueue<?>... queues) {
+               ExecutorService pool = CommonThreadPool.get();
+               try {
+                       pool.submit(oocTask(r, queues));
+               }
+               catch (Exception ex) {
+                       throw new DMLRuntimeException(ex);
+               }
+               finally {
+                       pool.shutdown();
+               }
+       }
+
+       private Runnable oocTask(Runnable r, LocalTaskQueue<?>... queues) {
+               return () -> {
+                       try {
+                               r.run();
+                       }
+                       catch (Exception ex) {
+                               DMLRuntimeException re = ex instanceof 
DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex);
+
+                               for (LocalTaskQueue<?> q : queues) {
+                                       q.propagateFailure(re);
+                               }
+
+                               // Rethrow to ensure proper future handling
+                               throw re;
+                       }
+               };
+       }
+
        /**
         * Tracks blocks and their counts to enable early emission
         * once all blocks for a given index are processed.
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
index 9a7059be51..06386c5d66 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java
@@ -79,13 +79,7 @@ public class ReblockOOCInstruction extends 
ComputationOOCInstruction {
                
                //create queue, spawn thread for asynchronous reading, and 
return
                LocalTaskQueue<IndexedMatrixValue> q = new 
LocalTaskQueue<IndexedMatrixValue>();
-               ExecutorService pool = CommonThreadPool.get();
-               try {
-                       pool.submit(() -> readBinaryBlock(q, 
min.getFileName()));
-               }
-               finally {
-                       pool.shutdown();
-               }
+               submitOOCTask(() -> readBinaryBlock(q, min.getFileName()), q);
                
                MatrixObject mout = ec.getMatrixObject(output);
                mout.setStreamHandle(q);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
index 212d0d5c56..fce5408960 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java
@@ -60,10 +60,7 @@ public class TransposeOOCInstruction extends 
ComputationOOCInstruction {
                LocalTaskQueue<IndexedMatrixValue> qOut = new 
LocalTaskQueue<>();
                ec.getMatrixObject(output).setStreamHandle(qOut);
 
-
-               ExecutorService pool = CommonThreadPool.get();
-               try {
-                       pool.submit(() -> {
+               submitOOCTask(() -> {
                                IndexedMatrixValue tmp = null;
                                try {
                                        while ((tmp = qIn.dequeueTask()) != 
LocalTaskQueue.NO_MORE_TASKS) {
@@ -79,11 +76,6 @@ public class TransposeOOCInstruction extends 
ComputationOOCInstruction {
                                catch(Exception ex) {
                                        throw new DMLRuntimeException(ex);
                                }
-                       });
-               } catch (Exception ex) {
-                       throw new DMLRuntimeException(ex);
-               } finally {
-                       pool.shutdown();
-               }
+               }, qIn, qOut);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
index 13cd5463ed..63f42f5bf1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java
@@ -61,9 +61,7 @@ public class UnaryOOCInstruction extends 
ComputationOOCInstruction {
                ec.getMatrixObject(output).setStreamHandle(qOut);
 
 
-               ExecutorService pool = CommonThreadPool.get();
-               try {
-                       pool.submit(() -> {
+               submitOOCTask(() -> {
                                IndexedMatrixValue tmp = null;
                                try {
                                        while ((tmp = qIn.dequeueTask()) != 
LocalTaskQueue.NO_MORE_TASKS) {
@@ -77,11 +75,6 @@ public class UnaryOOCInstruction extends 
ComputationOOCInstruction {
                                catch(Exception ex) {
                                        throw new DMLRuntimeException(ex);
                                }
-                       });
-               } catch (Exception ex) {
-                       throw new DMLRuntimeException(ex);
-               } finally {
-                       pool.shutdown();
-               }
+               }, qIn, qOut);
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java
 
b/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java
new file mode 100644
index 0000000000..3bd32d7eff
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public class OOCExceptionHandlingTest extends AutomatedTestBase {
+       private final static String TEST_NAME1 = "OOCExceptionHandling";
+       private final static String TEST_DIR = "functions/ooc/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
OOCExceptionHandlingTest.class.getSimpleName() + "/";
+       private static final String INPUT_NAME = "X";
+       private static final String INPUT_NAME_2 = "Y";
+       private static final String OUTPUT_NAME = "res";
+
+       private final static int rows = 1000;
+       private final static int cols = 1000;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               TestConfiguration config = new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+               addTestConfiguration(TEST_NAME1, config);
+       }
+
+       @Test
+       public void runOOCExceptionHandlingTest1() {
+               runOOCExceptionHandlingTest(500);
+       }
+
+       @Test
+       public void runOOCExceptionHandlingTest2() {
+               runOOCExceptionHandlingTest(750);
+       }
+
+
+       private void runOOCExceptionHandlingTest(int misalignVals) {
+               Types.ExecMode platformOld = 
setExecMode(Types.ExecMode.SINGLE_NODE);
+
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME1);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+                       programArgs = new String[] {"-explain", "-stats", 
"-ooc", "-args", input(INPUT_NAME), input(INPUT_NAME_2), output(OUTPUT_NAME)};
+
+                       // 1. Generate the data in-memory as MatrixBlock objects
+                       double[][] A_data = getRandomMatrix(rows, cols, 1, 2, 
1, 7);
+                       double[][] B_data = getRandomMatrix(rows, 1, 1, 2, 1, 
7);
+
+                       // 2. Convert the double arrays to MatrixBlock objects
+                       MatrixBlock A_mb = 
DataConverter.convertToMatrixBlock(A_data);
+                       MatrixBlock B_mb = 
DataConverter.convertToMatrixBlock(B_data);
+
+                       // 3. Create a binary matrix writer
+                       MatrixWriter writer = 
MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+
+                       // 4. Write matrix A to a binary SequenceFile
+
+                       // Here, we write two faulty matrices which will only 
be recognized at runtime
+                       writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, 
cols, misalignVals, A_mb.getNonZeros());
+                       HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), 
Types.ValueType.FP64,
+                               new MatrixCharacteristics(rows, cols, 1000, 
A_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+                       writer.writeMatrixToHDFS(B_mb, input(INPUT_NAME_2), 
rows, 1, 1000, B_mb.getNonZeros());
+                       HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + 
".mtd"), Types.ValueType.FP64,
+                               new MatrixCharacteristics(rows, 1, 1000, 
B_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+                       runTest(true, true, null, -1);
+               }
+               catch(IOException e) {
+                       throw new RuntimeException(e);
+               }
+               finally {
+                       resetExecMode(platformOld);
+               }
+       }
+}
diff --git a/src/test/scripts/functions/ooc/OOCExceptionHandling.dml 
b/src/test/scripts/functions/ooc/OOCExceptionHandling.dml
new file mode 100644
index 0000000000..6b7dc6038e
--- /dev/null
+++ b/src/test/scripts/functions/ooc/OOCExceptionHandling.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the input matrix as a stream
+X = read($1);
+b = read($2);
+
+OOC = colSums(X %*% b);
+
+write(OOC, $3, format="binary");

Reply via email to