This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 5015f63a79 [SYSTEMDS-3709] Additional tests for UDF backwards
compatibility
5015f63a79 is described below
commit 5015f63a7980f36e832bdffcbebba575cf8ddd62
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Jun 7 09:44:45 2024 +0200
[SYSTEMDS-3709] Additional tests for UDF backwards compatibility
This patch adds tests for the old SystemML UDF MultiInputCbind. The tests
ensure that the related DML script is properly compiled to an n-ary cbind,
and that, if the inputs are already column vectors and are reshaped to
column vectors, the unnecessary reshape is eliminated.
---
.../matrix/UDFBackwardsCompatibilityTest.java | 48 ++++++++++++++++++----
...owClassMeetTest.dml => MultiInputCbindTest.dml} | 10 ++++-
.../functions/binary/matrix/RowClassMeetTest.dml | 2 +-
3 files changed, 48 insertions(+), 12 deletions(-)
diff --git
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java
index f4961efc55..44cca625ea 100644
---
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java
@@ -19,16 +19,20 @@
package org.apache.sysds.test.functions.binary.matrix;
+import org.junit.Assert;
import org.junit.Test;
+
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
public class UDFBackwardsCompatibilityTest extends AutomatedTestBase
{
private final static String TEST_NAME1 = "RowClassMeetTest";
+ private final static String TEST_NAME2 = "MultiInputCbindTest";
private final static String TEST_DIR = "functions/binary/matrix/";
private final static String TEST_CLASS_DIR = TEST_DIR +
UDFBackwardsCompatibilityTest.class.getSimpleName() + "/";
@@ -44,29 +48,46 @@ public class UDFBackwardsCompatibilityTest extends
AutomatedTestBase
public void setUp() {
addTestConfiguration( TEST_NAME1,
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new
String[] { "C" }) );
+ addTestConfiguration( TEST_NAME2,
+ new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new
String[] { "C" }) );
}
@Test
public void testRowClassMeetDenseDense() {
- runUDFTest(TEST_NAME1, false, false, ExecType.CP);
+ runUDFTest(TEST_NAME1, false, false, false, false, ExecType.CP);
}
@Test
public void testRowClassMeetDenseSparse() {
- runUDFTest(TEST_NAME1, false, true, ExecType.CP);
+ runUDFTest(TEST_NAME1, false, true, false, false, ExecType.CP);
}
@Test
public void testRowClassMeetSparseDense() {
- runUDFTest(TEST_NAME1, true, false, ExecType.CP);
+ runUDFTest(TEST_NAME1, true, false, false, false, ExecType.CP);
}
@Test
public void testRowClassMeetSparseSparse() {
- runUDFTest(TEST_NAME1, true, true, ExecType.CP);
+ runUDFTest(TEST_NAME1, true, true, false, false, ExecType.CP);
+ }
+
+ @Test
+ public void testMultiInputCBindDenseDenseMatrixMatrix() {
+ runUDFTest(TEST_NAME2, false, false, false, false, ExecType.CP);
+ }
+
+ @Test
+ public void testMultiInputCBindDenseDenseMatrixVector() {
+ runUDFTest(TEST_NAME2, false, false, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testMultiInputCBindDenseDenseVectorVector() {
+ runUDFTest(TEST_NAME2, false, false, true, true, ExecType.CP);
}
- private void runUDFTest(String testname, boolean sparseM1, boolean
sparseM2, ExecType instType)
+ private void runUDFTest(String testname, boolean sparseM1, boolean
sparseM2, boolean vectorData, boolean vectorize, ExecType instType)
{
ExecMode platformOld = setExecMode(instType);
String TEST_NAME = testname;
@@ -76,18 +97,27 @@ public class UDFBackwardsCompatibilityTest extends
AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[]{"-explain","-args",
input("A"), input("B"), output("C")};
+ programArgs = new String[]{"-stats", "-explain","-args",
+ input("A"), input("B"),
String.valueOf(vectorize).toUpperCase(), output("C")};
//generate actual dataset
+ int nr = vectorData ? rows*cols : rows;
+ int nc = vectorData ? 1 : cols;
+
double[][] A = TestUtils.round(
- getRandomMatrix(rows, cols, 0, 10,
sparseM1?sparsity2:sparsity1, 7));
+ getRandomMatrix(nr, nc, 0, 10,
sparseM1?sparsity2:sparsity1, 7));
writeInputMatrixWithMTD("A", A, false);
double[][] B = TestUtils.round(
- getRandomMatrix(rows, cols, 0, 10,
sparseM2?sparsity2:sparsity1, 3));
+ getRandomMatrix(nr, nc, 0, 10,
sparseM2?sparsity2:sparsity1, 3));
writeInputMatrixWithMTD("B", B, false);
//run test case
- runTest(true, false, null, -1);
+ runTest(true, false, null, -1);
+
+ if( TEST_NAME.equals(TEST_NAME2) ) //check nary cbind
+ Assert.assertEquals(1,
Statistics.getCPHeavyHitterCount("cbind"));
+ if( vectorData && vectorize ) //check eliminated reshape
+
Assert.assertFalse(heavyHittersContainsString("rshape"));
}
finally {
rtplatform = platformOld;
diff --git a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
b/src/test/scripts/functions/binary/matrix/MultiInputCbindTest.dml
similarity index 86%
copy from src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
copy to src/test/scripts/functions/binary/matrix/MultiInputCbindTest.dml
index 9975f8d99d..77445023f6 100644
--- a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
+++ b/src/test/scripts/functions/binary/matrix/MultiInputCbindTest.dml
@@ -21,6 +21,12 @@
A = read($1);
B = read($2);
-[C,N] = rowClassMeet(A, B);
-write(C, $3);
+
+if( as.logical($3) ) {
+ A = matrix(A, rows=length(A), cols=1)
+ B = matrix(B, rows=length(B), cols=1)
+}
+
+R = cbind(cbind(A, B), A);
+write(R, $4);
diff --git a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
b/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
index 9975f8d99d..f2d9da3ae8 100644
--- a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
+++ b/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
@@ -22,5 +22,5 @@
A = read($1);
B = read($2);
[C,N] = rowClassMeet(A, B);
-write(C, $3);
+write(C, $4);