This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 01ec1d7afd [SYSTEMDS-3907] lmDS Algorithm Test for OOC Backend
01ec1d7afd is described below
commit 01ec1d7afd14c68aa57499a79de4d6393bb32d58
Author: Janardhan Pulivarthi <[email protected]>
AuthorDate: Sat Oct 18 10:42:55 2025 +0200
[SYSTEMDS-3907] lmDS Algorithm Test for OOC Backend
Closes #2338.
---
.../sysds/hops/rewrite/RewriteInjectOOCTee.java | 21 +++-
.../sysds/runtime/matrix/data/LibCommonsMath.java | 2 +-
.../sysds/runtime/matrix/data/LibMatrixMult.java | 3 +-
.../apache/sysds/test/functions/ooc/lmDSTest.java | 125 +++++++++++++++++++++
src/test/scripts/functions/ooc/lmDS.dml | 28 +++++
5 files changed, 176 insertions(+), 3 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
index 6acd314d9a..f6033a805a 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java
@@ -24,6 +24,7 @@ import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.ReorgOp;
import java.util.ArrayList;
import java.util.HashMap;
@@ -138,7 +139,7 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
if (DMLScript.USE_OOC
&& hop.getDataType().isMatrix()
&& !HopRewriteUtils.isData(hop, OpOpData.TEE)
- && hop.getParent().size() > 1)
+ && hop.getParent().size() > 1)
{
rewriteCandidates.add(hop);
}
@@ -174,4 +175,22 @@ public class RewriteInjectOOCTee extends HopRewriteRule {
handledHop.put(sharedInput.getHopID(), teeOp);
rewrittenHops.add(sharedInput.getHopID());
}
+
+ @SuppressWarnings("unused")
+ private boolean isSelfTranposePattern (Hop hop) {
+ boolean hasTransposeConsumer = false; // t(X)
+ boolean hasMatrixMultiplyConsumer = false; // %*%
+
+ for (Hop parent: hop.getParent()) {
+ if (parent instanceof ReorgOp) {
+ if
(HopRewriteUtils.isTransposeOperation(parent)) {
+ hasTransposeConsumer = true;
+ }
+ }
+ else if (HopRewriteUtils.isMatrixMultiply(parent)) {
+ hasMatrixMultiplyConsumer = true;
+ }
+ }
+ return hasTransposeConsumer && hasMatrixMultiplyConsumer;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
index 4dc85755c6..5fe837ed1f 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
@@ -201,7 +201,7 @@ public class LibCommonsMath
* @param in2 matrix object 2
* @return matrix block
*/
- private static MatrixBlock computeSolve(MatrixBlock in1, MatrixBlock
in2) {
+ public static MatrixBlock computeSolve(MatrixBlock in1, MatrixBlock
in2) {
//convert to commons math BlockRealMatrix instead of
Array2DRowRealMatrix
//to avoid unnecessary conversion as QR internally creates a
BlockRealMatrix
BlockRealMatrix matrixInput =
DataConverter.convertToBlockRealMatrix(in1);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index af702cb7fa..5753fbbadb 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -452,8 +452,9 @@ public class LibMatrixMult
//
"("+mV.isInSparseFormat()+","+mV.getNumRows()+","+mV.getNumColumns()+","+mV.getNonZeros()+")
in "+time.stop());
}
- public static void matrixMultTransposeSelf( MatrixBlock m1, MatrixBlock
ret, boolean leftTranspose ) {
+ public static MatrixBlock matrixMultTransposeSelf( MatrixBlock m1,
MatrixBlock ret, boolean leftTranspose ) {
matrixMultTransposeSelf(m1, ret, leftTranspose, true);
+ return ret;
}
public static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock
ret, boolean leftTranspose, boolean copyToLowerTriangle){
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java
b/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java
new file mode 100644
index 0000000000..258f6a0531
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/lmDSTest.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public class lmDSTest extends AutomatedTestBase {
+ private final static String TEST_NAME1 = "lmDS";
+ private final static String TEST_DIR = "functions/ooc/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
lmDSTest.class.getSimpleName() + "/";
+ private final static double eps = 1e-10;
+ private static final String INPUT_NAME = "X";
+ private static final String INPUT_NAME2 = "y";
+ private static final String OUTPUT_NAME = "R";
+
+ private final static int rows = 100000;
+ private final static int cols_wide = 500; //TODO larger than 1000
+ private final static int cols_skinny = 10;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ TestConfiguration config = new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1);
+ addTestConfiguration(TEST_NAME1, config);
+ }
+
+ @Test
+ @Ignore //FIXME
+ public void testlmDS1() {
+ runMatrixVectorMultiplicationTest(cols_wide);
+ }
+
+ @Test
+ @Ignore //FIXME
+ public void testlmDS2() {
+ runMatrixVectorMultiplicationTest(cols_skinny);
+ }
+
+ private void runMatrixVectorMultiplicationTest(int cols)
+ {
+ Types.ExecMode platformOld =
setExecMode(Types.ExecMode.SINGLE_NODE);
+
+ try
+ {
+ getAndLoadTestConfiguration(TEST_NAME1);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+ programArgs = new String[]{"-explain", "-stats", "-ooc",
+ "-args", input(INPUT_NAME),
input(INPUT_NAME2), output(OUTPUT_NAME)};
+
+ // 1. Generate the data in-memory as double[][] arrays
+ double[][] X_data = getRandomMatrix(rows, cols, 0, 1,
1.0, 7);
+ double[][] y_data = getRandomMatrix(rows, 1, 0, 1, 1.0,
3);
+
+ // 2. Convert the double arrays to MatrixBlock objects
+ MatrixBlock X_mb =
DataConverter.convertToMatrixBlock(X_data);
+ MatrixBlock y_mb =
DataConverter.convertToMatrixBlock(y_data);
+
+ // 3. Create a binary matrix writer
+ MatrixWriter writer =
MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+
+ // 4. Write matrix X to a binary SequenceFile
+ writer.writeMatrixToHDFS(X_mb, input(INPUT_NAME), rows,
cols, 1000, X_mb.getNonZeros());
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"),
Types.ValueType.FP64,
+ new MatrixCharacteristics(rows, cols, 1000,
X_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+ // 5. Write vector y to a binary SequenceFile
+ writer.writeMatrixToHDFS(y_mb, input(INPUT_NAME2),
rows, 1, 1000, y_mb.getNonZeros());
+ HDFSTool.writeMetaDataFile(input(INPUT_NAME2 + ".mtd"),
Types.ValueType.FP64,
+ new MatrixCharacteristics(rows, 1, 1000,
y_mb.getNonZeros()), Types.FileFormat.BINARY);
+
+ runTest(true, false, null, -1);
+ MatrixBlock C = DataConverter.readMatrixFromHDFS(
+ output(OUTPUT_NAME), Types.FileFormat.BINARY,
rows, cols, 1000, 1000);
+
+ //expected results
+ MatrixBlock xtx =
LibMatrixMult.matrixMultTransposeSelf(X_mb, new MatrixBlock(cols,cols,false),
true);
+ MatrixBlock xt = LibMatrixReorg.transpose(X_mb);
+ MatrixBlock xty = LibMatrixMult.matrixMult(xt, y_mb);
+ MatrixBlock ret = LibCommonsMath.computeSolve(xtx, xty);
+ for(int i = 0; i < cols; i++)
+ Assert.assertEquals(ret.get(i, 0), C.get(i,0),
eps);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+}
diff --git a/src/test/scripts/functions/ooc/lmDS.dml
b/src/test/scripts/functions/ooc/lmDS.dml
new file mode 100644
index 0000000000..930d956c50
--- /dev/null
+++ b/src/test/scripts/functions/ooc/lmDS.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1)
+y = read($2)
+
+XtX = t(X) %*% X; # ncol(X) x ncol(X)
+Xty = t(X) %*% y; # ncol(X) x 1
+R = solve(XtX, Xty)
+write(R, $3, format="binary")