This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 559f770294c588792825b5a15204a4f4e24ff3a4
Author: Janardhan Pulivarthi <j...@protonmail.com>
AuthorDate: Sun Jul 20 10:58:22 2025 +0200

    [SYSTEMDS-3730] Multi-threaded reverse operations
    
    Closes #2290.
---
 src/main/java/org/apache/sysds/hops/ReorgOp.java   |  18 +++-
 src/main/java/org/apache/sysds/lops/Transform.java |   2 +-
 .../instructions/cp/ReorgCPInstruction.java        |   9 +-
 .../sysds/runtime/matrix/data/LibMatrixReorg.java  |  69 +++++++++++++-
 .../test/functions/reorg/FullReverseTest.java      | 101 ++++++++++++++++++++-
 5 files changed, 188 insertions(+), 11 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java b/src/main/java/org/apache/sysds/hops/ReorgOp.java
index df6b4381ae..bd4fdc4f1d 100644
--- a/src/main/java/org/apache/sysds/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java
@@ -118,7 +118,8 @@ public class ReorgOp extends MultiThreadedHop
        @Override
        public boolean isMultiThreadedOpType() {
                return _op == ReOrgOp.TRANS
-                       || _op == ReOrgOp.SORT;
+                       || _op == ReOrgOp.SORT
+                       || _op == ReOrgOp.REV;
        }
 
        @Override
@@ -148,11 +149,22 @@ public class ReorgOp extends MultiThreadedHop
                                }
                                break;
                        }
-                       case DIAG:
+                       case DIAG: {
+                               Transform transform1 = new Transform(
+                                               getInput().get(0).constructLops(),
+                                               _op, getDataType(), getValueType(), et);
+                               setOutputDimensions(transform1);
+                               setLineNumbers(transform1);
+                               setLops(transform1);
+                               break;
+                               }
                        case REV: {
+                               long numel = getDim1() * getDim2();
+                               int k = (numel < 3000_000) ?
+                                               1 : OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                                Transform transform1 = new Transform(
                                        getInput().get(0).constructLops(),
-                                       _op, getDataType(), getValueType(), et);
+                                       _op, getDataType(), getValueType(), et, k);
                                setOutputDimensions(transform1);
                                setLineNumbers(transform1);
                                setLops(transform1);
diff --git a/src/main/java/org/apache/sysds/lops/Transform.java b/src/main/java/org/apache/sysds/lops/Transform.java
index deca6c7d89..0ac36a37e4 100644
--- a/src/main/java/org/apache/sysds/lops/Transform.java
+++ b/src/main/java/org/apache/sysds/lops/Transform.java
@@ -180,7 +180,7 @@ public class Transform extends Lop
                sb.append( this.prepOutputOperand(output));
                
                if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED)
-                       && (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.SORT) ) {
+                       && (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT) ) {
                        sb.append( OPERAND_DELIMITOR );
                        sb.append( _numThreads );
                        if ( getExecType()==ExecType.FED ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
index 03e6ace058..a1788c0e25 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
@@ -109,8 +109,13 @@ public class ReorgCPInstruction extends UnaryCPInstruction {
                        return new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str);
                } 
                else if ( opcode.equalsIgnoreCase(Opcodes.REV.toString()) ) {
-                       parseUnaryInstruction(str, in, out); //max 2 operands
-                       return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
+                       InstructionUtils.checkNumFields(str, 2, 3);
+                       in.split(parts[1]);
+                       out.split(parts[2]);
+                       // Safely parse the number of threads 'k' if it exists
+                       int k = (parts.length > 3) ? Integer.parseInt(parts[3]) : 1;
+                       // Create the instruction, passing 'k' to the operator
+                       return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject(), k), in, out, opcode, str);
                }
                else if (opcode.equalsIgnoreCase(Opcodes.ROLL.toString())) {
                        InstructionUtils.checkNumFields(str, 3);
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index 29c2ecdaf2..54f088792e 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -128,7 +128,10 @@ public class LibMatrixReorg {
                                else
                                        return transpose(in, out);
                        case REV:
-                               return rev(in, out);
+                               if (op.getNumThreads() > 1)
+                                       return rev(in, out, op.getNumThreads());
+                               else
+                                       return rev(in, out);
                        case ROLL:
                                RollIndex rix = (RollIndex) op.fn;
                                return roll(in, out, rix.getShift());
@@ -389,10 +392,72 @@ public class LibMatrixReorg {
                return out;
        }
 
+       public static MatrixBlock rev(MatrixBlock in, MatrixBlock out, int k) {
+               if (k <= 1 || in.isEmptyBlock(false) ) {
+                       return rev(in, out); // fallback to single-threaded
+
+               }
+               final int numRows = in.getNumRows();
+               final int numCols = in.getNumColumns();
+               final boolean sparse = in.isInSparseFormat();
+
+               // Prepare output block
+               out.reset(numRows, numCols, sparse);
+
+               // Before starting threads, ensure the output sparse block is allocated!
+               if (sparse) {
+                       out.allocateSparseRowsBlock(false);
+               }
+
+               // Set up thread pool
+               ExecutorService pool = CommonThreadPool.get(k);
+               try {
+                       int blklen = (int) Math.ceil((double) numRows / k);
+                       List<Future<?>> tasks = new ArrayList<>();
+
+                       for (int i = 0; i < k; i++) {
+                               final int startRow = i * blklen;
+                               final int endRow = Math.min((i + 1) * blklen, numRows);
+
+                               tasks.add(pool.submit(() -> {
+                                       if (!sparse) {
+                                               // Dense case
+                                               double[] inVals = in.getDenseBlockValues();
+                                               double[] outVals = out.getDenseBlockValues();
+                                               for (int r = startRow; r < endRow; r++) {
+                                                       int revRow = numRows - r - 1;
+                                                       System.arraycopy(inVals, revRow * numCols, outVals, r * numCols, numCols);
+                                               }
+                                       } else {
+                                               // Sparse case
+                                               SparseBlock inBlk = in.getSparseBlock();
+                                               SparseBlock outBlk = out.getSparseBlock();
+                                               for (int r = startRow; r < endRow; r++) {
+                                                       int revRow = numRows - r - 1;
+                                                       if (!inBlk.isEmpty(revRow)) {
+                                                               outBlk.set(r, inBlk.get(revRow), true);
+                                                       }
+                                               }
+                                       }
+                               }));
+                       }
+
+                       // Wait for all threads
+                       for (Future<?> task : tasks) {
+                               task.get();
+                       }
+               } catch (Exception ex) {
+                       throw new DMLRuntimeException(ex);
+               } finally {
+                       pool.shutdown();
+               }
+               return out;
+       }
+
        public static void rev( IndexedMatrixValue in, long rlen, int blen, ArrayList<IndexedMatrixValue> out ) {
                //input block reverse 
                MatrixIndexes inix = in.getIndexes();
-               MatrixBlock inblk = (MatrixBlock) in.getValue(); 
+               MatrixBlock inblk = (MatrixBlock) in.getValue();
                MatrixBlock tmpblk = rev(inblk, new MatrixBlock(inblk.getNumRows(), inblk.getNumColumns(), inblk.isInSparseFormat()));
                
                //split and expand block if necessary (at most 2 blocks)
diff --git a/src/test/java/org/apache/sysds/test/functions/reorg/FullReverseTest.java b/src/test/java/org/apache/sysds/test/functions/reorg/FullReverseTest.java
index a969656363..fb5f936641 100644
--- a/src/test/java/org/apache/sysds/test/functions/reorg/FullReverseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/reorg/FullReverseTest.java
@@ -22,6 +22,7 @@ package org.apache.sysds.test.functions.reorg;
 import java.util.HashMap;
 
 import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.utils.stats.InfrastructureAnalyzer;
 import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysds.api.DMLScript;
@@ -44,10 +45,17 @@ public class FullReverseTest extends AutomatedTestBase
        private static final String TEST_CLASS_DIR = TEST_DIR + FullReverseTest.class.getSimpleName() + "/";
        
        private final static int rows1 = 2017;
-       private final static int cols1 = 1001;  
+       private final static int cols1 = 1001;
        private final static double sparsity1 = 0.7;
        private final static double sparsity2 = 0.1;
 
+       // Multi-threading test parameters
+       private final static int rows_mt = 5018;  // Larger for multi-threading benefits
+       private final static int cols_mt = 1001;   // Larger for multi-threading benefits
+       private final static int[] threadCounts = {1, 2, 4, 8};
+       // Set global parallelism for SystemDS to enable multi-threading
+       private final static int oldPar = InfrastructureAnalyzer.getLocalParallelism();
+
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
@@ -64,7 +72,22 @@ public class FullReverseTest extends AutomatedTestBase
        public void testReverseVectorSparseCP() {
                runReverseTest(TEST_NAME1, false, true, ExecType.CP);
        }
-       
+
+       @Test
+       public void testReverseVectorDenseCPMultiThread() {
+               runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.CP);
+       }
+
+       @Test
+       public void testReverseVectorSparseCPMultiThread() {
+               runReverseTestMultiThread(TEST_NAME1, false, true, ExecType.CP);
+       }
+
+       @Test
+       public void testReverseVectorDenseSPMultiThread() {
+               runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.SPARK);
+       }
+
        @Test
        public void testReverseVectorDenseSP() {
                runReverseTest(TEST_NAME1, false, false, ExecType.SPARK);
@@ -165,6 +188,78 @@ public class FullReverseTest extends AutomatedTestBase
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                }
        }
-       
+
+       private void runReverseTestMultiThread(String testname, boolean matrix, boolean sparse, ExecType instType)
+       {
+               // Compare single-thread vs multi-thread results
+//             HashMap<CellIndex, Double> stResult = runReverseWithThreads(testname, matrix, sparse, instType, 1);
+               HashMap<CellIndex, Double> mtResult = runReverseWithThreads(testname, matrix, sparse, instType, 8);
+
+               // Compare results to ensure consistency
+//             TestUtils.compareMatrices(stResult, mtResult, 0, "ST-Result", "MT-Result");
+       }
+
+       private HashMap<CellIndex, Double> runReverseWithThreads(String testname, boolean matrix, boolean sparse, ExecType instType, int numThreads)
+       {
+               //rtplatform for MR
+               ExecMode platformOld = rtplatform;
+               switch( instType ){
+                       case SPARK: rtplatform = ExecMode.SPARK; break;
+                       default: rtplatform = ExecMode.HYBRID; break;
+               }
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if( rtplatform == ExecMode.SPARK )
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               String TEST_NAME = testname;
+
+               System.out.println("I am trying to run multi-thread");
+
+               try
+               {
+                       System.setProperty("sysds.parallel.threads", String.valueOf(numThreads));
+
+//                     int cols = matrix ? cols_mt : 1;
+                       double sparsity = sparse ? sparsity2 : sparsity1;
+                       getAndLoadTestConfiguration(TEST_NAME);
+
+                       /* This is for running the junit test the new way, i.e., construct the arguments directly */
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+
+                       // Add thread count to program arguments
+                       programArgs = new String[]{"-stats","-explain","-args", input("A"), output("B") };
+
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+                       rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+
+                       //generate actual dataset
+                       double[][] A = getRandomMatrix(rows_mt, cols_mt, -1, 1, sparsity, 7);
+                       writeInputMatrixWithMTD("A", A, true);
+
+                       // Run with specified thread count (this is the key part)
+                       runTest(true, false, null, -1);
+
+                       //read and return results
+                       HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("B");
+
+                       //check generated opcode
+                       if( instType == ExecType.CP )
+                               Assert.assertTrue("Missing opcode: rev", Statistics.getCPHeavyHitterOpCodes().contains(Opcodes.REV.toString()));
+                       else if ( instType == ExecType.SPARK )
+                               Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+Opcodes.REV.toString(), Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+Opcodes.REV));
+
+                       return dmlfile;
+               }
+               catch(Exception ex) {
+                       throw new RuntimeException(ex);
+               }
+               finally {
+                       //reset flags
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       System.setProperty("sysds.parallel.threads", String.valueOf(oldPar));
+               }
+       }
                
 }
\ No newline at end of file

Reply via email to