This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 559f770294c588792825b5a15204a4f4e24ff3a4 Author: Janardhan Pulivarthi <j...@protonmail.com> AuthorDate: Sun Jul 20 10:58:22 2025 +0200 [SYSTEMDS-3730] Multi-threaded reverse operations Closes #2290. --- src/main/java/org/apache/sysds/hops/ReorgOp.java | 18 +++- src/main/java/org/apache/sysds/lops/Transform.java | 2 +- .../instructions/cp/ReorgCPInstruction.java | 9 +- .../sysds/runtime/matrix/data/LibMatrixReorg.java | 69 +++++++++++++- .../test/functions/reorg/FullReverseTest.java | 101 ++++++++++++++++++++- 5 files changed, 188 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java b/src/main/java/org/apache/sysds/hops/ReorgOp.java index df6b4381ae..bd4fdc4f1d 100644 --- a/src/main/java/org/apache/sysds/hops/ReorgOp.java +++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java @@ -118,7 +118,8 @@ public class ReorgOp extends MultiThreadedHop @Override public boolean isMultiThreadedOpType() { return _op == ReOrgOp.TRANS - || _op == ReOrgOp.SORT; + || _op == ReOrgOp.SORT + || _op == ReOrgOp.REV; } @Override @@ -148,11 +149,22 @@ public class ReorgOp extends MultiThreadedHop } break; } - case DIAG: + case DIAG: { + Transform transform1 = new Transform( + getInput().get(0).constructLops(), + _op, getDataType(), getValueType(), et); + setOutputDimensions(transform1); + setLineNumbers(transform1); + setLops(transform1); + break; + } case REV: { + long numel = getDim1() * getDim2(); + int k = (numel < 3000_000) ? + 1 : OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); Transform transform1 = new Transform( getInput().get(0).constructLops(), - _op, getDataType(), getValueType(), et); + _op, getDataType(), getValueType(), et, k); setOutputDimensions(transform1); setLineNumbers(transform1); setLops(transform1); diff --git a/src/main/java/org/apache/sysds/lops/Transform.java b/src/main/java/org/apache/sysds/lops/Transform.java index deca6c7d89..0ac36a37e4 100644 --- a/src/main/java/org/apache/sysds/lops/Transform.java +++ b/src/main/java/org/apache/sysds/lops/Transform.java @@ -180,7 +180,7 @@ public class Transform extends Lop sb.append( this.prepOutputOperand(output)); if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED) - && (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.SORT) ) { + && (_operation == ReOrgOp.TRANS || _operation == ReOrgOp.REV || _operation == ReOrgOp.SORT) ) { sb.append( OPERAND_DELIMITOR ); sb.append( _numThreads ); if ( getExecType()==ExecType.FED ) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java index 03e6ace058..a1788c0e25 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java @@ -109,8 +109,13 @@ public class ReorgCPInstruction extends UnaryCPInstruction { return new ReorgCPInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str); } else if ( opcode.equalsIgnoreCase(Opcodes.REV.toString()) ) { - parseUnaryInstruction(str, in, out); //max 2 operands - return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str); + InstructionUtils.checkNumFields(str, 2, 3); + in.split(parts[1]); + out.split(parts[2]); + // Safely parse the number of threads 'k' if it exists + int k = (parts.length > 3) ? Integer.parseInt(parts[3]) : 1; + // Create the instruction, passing 'k' to the operator + return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject(), k), in, out, opcode, str); } else if (opcode.equalsIgnoreCase(Opcodes.ROLL.toString())) { InstructionUtils.checkNumFields(str, 3); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 29c2ecdaf2..54f088792e 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -128,7 +128,10 @@ public class LibMatrixReorg { else return transpose(in, out); case REV: - return rev(in, out); + if (op.getNumThreads() > 1) + return rev(in, out, op.getNumThreads()); + else + return rev(in, out); case ROLL: RollIndex rix = (RollIndex) op.fn; return roll(in, out, rix.getShift()); @@ -389,10 +392,72 @@ public class LibMatrixReorg { return out; } + public static MatrixBlock rev(MatrixBlock in, MatrixBlock out, int k) { + if (k <= 1 || in.isEmptyBlock(false) ) { + return rev(in, out); // fallback to single-threaded + + } + final int numRows = in.getNumRows(); + final int numCols = in.getNumColumns(); + final boolean sparse = in.isInSparseFormat(); + + // Prepare output block + out.reset(numRows, numCols, sparse); + + // Before starting threads, ensure the output sparse block is allocated! + if (sparse) { + out.allocateSparseRowsBlock(false); + } + + // Set up thread pool + ExecutorService pool = CommonThreadPool.get(k); + try { + int blklen = (int) Math.ceil((double) numRows / k); + List<Future<?>> tasks = new ArrayList<>(); + + for (int i = 0; i < k; i++) { + final int startRow = i * blklen; + final int endRow = Math.min((i + 1) * blklen, numRows); + + tasks.add(pool.submit(() -> { + if (!sparse) { + // Dense case + double[] inVals = in.getDenseBlockValues(); + double[] outVals = out.getDenseBlockValues(); + for (int r = startRow; r < endRow; r++) { + int revRow = numRows - r - 1; + System.arraycopy(inVals, revRow * numCols, outVals, r * numCols, numCols); + } + } else { + // Sparse case + SparseBlock inBlk = in.getSparseBlock(); + SparseBlock outBlk = out.getSparseBlock(); + for (int r = startRow; r < endRow; r++) { + int revRow = numRows - r - 1; + if (!inBlk.isEmpty(revRow)) { + outBlk.set(r, inBlk.get(revRow), true); + } + } + } + })); + } + + // Wait for all threads + for (Future<?> task : tasks) { + task.get(); + } + } catch (Exception ex) { + throw new DMLRuntimeException(ex); + } finally { + pool.shutdown(); + } + return out; + } + public static void rev( IndexedMatrixValue in, long rlen, int blen, ArrayList<IndexedMatrixValue> out ) { //input block reverse MatrixIndexes inix = in.getIndexes(); - MatrixBlock inblk = (MatrixBlock) in.getValue(); + MatrixBlock inblk = (MatrixBlock) in.getValue(); MatrixBlock tmpblk = rev(inblk, new MatrixBlock(inblk.getNumRows(), inblk.getNumColumns(), inblk.isInSparseFormat())); //split and expand block if necessary (at most 2 blocks) diff --git a/src/test/java/org/apache/sysds/test/functions/reorg/FullReverseTest.java b/src/test/java/org/apache/sysds/test/functions/reorg/FullReverseTest.java index a969656363..fb5f936641 100644 --- a/src/test/java/org/apache/sysds/test/functions/reorg/FullReverseTest.java +++ b/src/test/java/org/apache/sysds/test/functions/reorg/FullReverseTest.java @@ -22,6 +22,7 @@ package org.apache.sysds.test.functions.reorg; import java.util.HashMap; import org.apache.sysds.common.Opcodes; +import org.apache.sysds.utils.stats.InfrastructureAnalyzer; import org.junit.Assert; import org.junit.Test; import org.apache.sysds.api.DMLScript; @@ -44,10 +45,17 @@ public class FullReverseTest extends AutomatedTestBase private static final String TEST_CLASS_DIR = TEST_DIR + FullReverseTest.class.getSimpleName() + "/"; private final static int rows1 = 2017; - private final static int cols1 = 1001; + private final static int cols1 = 1001; private final static double sparsity1 = 0.7; private final static double sparsity2 = 0.1; + // Multi-threading test parameters + private final static int rows_mt = 5018; // Larger for multi-threading benefits + private final static int cols_mt = 1001; // Larger for multi-threading benefits + private final static int[] threadCounts = {1, 2, 4, 8}; + // Set global parallelism for SystemDS to enable multi-threading + private final static int oldPar = InfrastructureAnalyzer.getLocalParallelism(); + @Override public void setUp() { TestUtils.clearAssertionInformation(); @@ -64,7 +72,22 @@ public class FullReverseTest extends AutomatedTestBase public void testReverseVectorSparseCP() { runReverseTest(TEST_NAME1, false, true, ExecType.CP); } - + + @Test + public void testReverseVectorDenseCPMultiThread() { + runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.CP); + } + + @Test + public void testReverseVectorSparseCPMultiThread() { + runReverseTestMultiThread(TEST_NAME1, false, true, ExecType.CP); + } + + @Test + public void testReverseVectorDenseSPMultiThread() { + runReverseTestMultiThread(TEST_NAME1, false, false, ExecType.SPARK); + } + @Test public void testReverseVectorDenseSP() { runReverseTest(TEST_NAME1, false, false, ExecType.SPARK); @@ -165,6 +188,78 @@ public class FullReverseTest extends AutomatedTestBase DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } } - + + private void runReverseTestMultiThread(String testname, boolean matrix, boolean sparse, ExecType instType) + { + // Compare single-thread vs multi-thread results +// HashMap<CellIndex, Double> stResult = runReverseWithThreads(testname, matrix, sparse, instType, 1); + HashMap<CellIndex, Double> mtResult = runReverseWithThreads(testname, matrix, sparse, instType, 8); + + // Compare results to ensure consistency +// TestUtils.compareMatrices(stResult, mtResult, 0, "ST-Result", "MT-Result"); + } + + private HashMap<CellIndex, Double> runReverseWithThreads(String testname, boolean matrix, boolean sparse, ExecType instType, int numThreads) + { + //rtplatform for MR + ExecMode platformOld = rtplatform; + switch( instType ){ + case SPARK: rtplatform = ExecMode.SPARK; break; + default: rtplatform = ExecMode.HYBRID; break; + } + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == ExecMode.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + String TEST_NAME = testname; + + System.out.println("I am trying to run multi-thread"); + + try + { + System.setProperty("sysds.parallel.threads", String.valueOf(numThreads)); + +// int cols = matrix ? cols_mt : 1; + double sparsity = sparse ? sparsity2 : sparsity1; + getAndLoadTestConfiguration(TEST_NAME); + + /* This is for running the junit test the new way, i.e., construct the arguments directly */ + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + + // Add thread count to program arguments + programArgs = new String[]{"-stats","-explain","-args", input("A"), output("B") }; + + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir(); + + //generate actual dataset + double[][] A = getRandomMatrix(rows_mt, cols_mt, -1, 1, sparsity, 7); + writeInputMatrixWithMTD("A", A, true); + + // Run with specified thread count (this is the key part) + runTest(true, false, null, -1); + + //read and return results + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("B"); + + //check generated opcode + if( instType == ExecType.CP ) + Assert.assertTrue("Missing opcode: rev", Statistics.getCPHeavyHitterOpCodes().contains(Opcodes.REV.toString())); + else if ( instType == ExecType.SPARK ) + Assert.assertTrue("Missing opcode: "+Instruction.SP_INST_PREFIX+Opcodes.REV.toString(), Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX+Opcodes.REV)); + + return dmlfile; + } + catch(Exception ex) { + throw new RuntimeException(ex); + } + finally { + //reset flags + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + System.setProperty("sysds.parallel.threads", String.valueOf(oldPar)); + } + } } \ No newline at end of file