This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new baacab4e59 [SYSTEMDS-3908] Fix OOC matmult compilation w/ transpose 
rewrite
baacab4e59 is described below

commit baacab4e5959b3492419481a47552b6eb4969562
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Oct 18 13:10:12 2025 +0200

    [SYSTEMDS-3908] Fix OOC matmult compilation w/ transpose rewrite
    
    In CP, we rewrite t(X)%*%y to t(t(y)%*%X) if the two transposes are
    much smaller, especially if they are vectors, because vector transpose
    is a metadata operation. However, if y is an OOC stream, this rewrite
    destroyed the pipeline, and incomplete exception handling and other
    primitives made the resulting issue hard to debug.
---
 src/main/java/org/apache/sysds/hops/AggBinaryOp.java             | 3 ++-
 .../java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java  | 3 ++-
 src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java  | 9 +++------
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 8685524a3f..e3cf95e573 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -624,7 +624,8 @@ public class AggBinaryOp extends MultiThreadedHop {
 
                //Handle Y or actualY for transpose
                Lop yLop = isYTransposed ? actualY.constructLops() : 
Y.constructLops();
-               ExecType inputReorgExecType = (Y.hasFederatedOutput()) ? 
ExecType.FED : ExecType.CP;
+               ExecType inputReorgExecType = (Y.hasFederatedOutput()) ? 
ExecType.FED :
+                       (et==ExecType.OOC) ? ExecType.OOC : ExecType.CP;
 
                //right vector transpose
                Lop tY = (yLop instanceof Transform && 
((Transform)yLop).getOp() == ReOrgOp.TRANS) ?
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
index f6033a805a..2e849f74fb 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
@@ -139,7 +139,8 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
                if (DMLScript.USE_OOC 
                        && hop.getDataType().isMatrix()
                        && !HopRewriteUtils.isData(hop, OpOpData.TEE)
-                       && hop.getParent().size() > 1)
+                       && hop.getParent().size() > 1
+                       && isSelfTranposePattern(hop)) //FIXME remove
                {
                        rewriteCandidates.add(hop);
                }
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java 
b/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java
index 258f6a0531..e6a147775f 100644
--- a/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java
@@ -33,7 +33,6 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 
 import java.io.IOException;
@@ -47,7 +46,7 @@ public class lmDSTest extends AutomatedTestBase {
        private static final String INPUT_NAME2 = "y";
        private static final String OUTPUT_NAME = "R";
 
-       private final static int rows = 100000;
+       private final static int rows = 10000;
        private final static int cols_wide = 500; //TODO larger than 1000
        private final static int cols_skinny = 10;
 
@@ -59,13 +58,11 @@ public class lmDSTest extends AutomatedTestBase {
        }
 
        @Test
-       @Ignore //FIXME
        public void testlmDS1() {
                runMatrixVectorMultiplicationTest(cols_wide);
        }
 
        @Test
-       @Ignore //FIXME
        public void testlmDS2() {
                runMatrixVectorMultiplicationTest(cols_skinny);
        }
@@ -80,7 +77,7 @@ public class lmDSTest extends AutomatedTestBase {
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
                        programArgs = new String[]{"-explain", "-stats", "-ooc",
-                                       "-args", input(INPUT_NAME), 
input(INPUT_NAME2), output(OUTPUT_NAME)};
+                               "-args", input(INPUT_NAME), input(INPUT_NAME2), 
output(OUTPUT_NAME)};
 
                        // 1. Generate the data in-memory as MatrixBlock objects
                        double[][] X_data = getRandomMatrix(rows, cols, 0, 1, 
1.0, 7);
@@ -105,7 +102,7 @@ public class lmDSTest extends AutomatedTestBase {
 
                        runTest(true, false, null, -1);
                        MatrixBlock C = DataConverter.readMatrixFromHDFS(
-                               output(OUTPUT_NAME), Types.FileFormat.BINARY, 
rows, cols, 1000, 1000);
+                               output(OUTPUT_NAME), Types.FileFormat.BINARY, 
cols, 1, 1000, 1000);
                        
                        //expected results
                        MatrixBlock xtx = 
LibMatrixMult.matrixMultTransposeSelf(X_mb, new MatrixBlock(cols,cols,false), 
true);

Reply via email to