This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 0664e1fd78 [MINOR] Add a few MatrixMult and asFrame Tests
0664e1fd78 is described below
commit 0664e1fd782dd34a5d3abce6db0f0a652bf9f0d3
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Tue Oct 17 14:02:43 2023 +0200
[MINOR] Add a few MatrixMult and asFrame Tests
---
src/test/java/org/apache/sysds/test/TestUtils.java | 2 +-
.../component/frame/FrameFromMatrixBlockTest.java | 51 +++++-
.../test/component/matrix/MatrixMultiplyTest.java | 179 +++++++++++++++++++++
3 files changed, 225 insertions(+), 7 deletions(-)
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index 907c9adab8..9e866e5b33 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -1396,7 +1396,7 @@ public class TestUtils {
if(countErrors != 0)
fail(message + "\n" + countErrors + " values
are not in equal");
if(avgDistance > maxAveragePercentDistance)
- fail(message + "\nThe avg distance in bits: " +
avgDistance + " was higher than max: " + maxAveragePercentDistance);
+ fail(message + "\nThe avg distance in percent:
" + avgDistance + " was higher than max: " + maxAveragePercentDistance);
}
}
diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java
index 76c7197322..bc0e242f9d 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java
@@ -134,20 +134,48 @@ public class FrameFromMatrixBlockTest {
verifyEquivalence(mb, fb, ValueType.FP64);
}
+	/** Random input generated with sparsity 1.0, converted without rounding. */
+	@Test
+	public void random() {
+		convertAndVerify(TestUtils.generateTestMatrixBlock(100, 10, 0, 199, 1.0, 213));
+	}
+
+	/** Random input generated with sparsity 1.0 and rounded up via {@code TestUtils.ceil}. */
+	@Test
+	public void randomRounded() {
+		convertAndVerify(TestUtils.ceil(TestUtils.generateTestMatrixBlock(100, 10, 0, 199, 1.0, 213)));
+	}
+
+	/** Rounded random input generated with sparsity 0.1. */
+	@Test
+	public void randomSparse() {
+		convertAndVerify(TestUtils.ceil(TestUtils.generateTestMatrixBlock(100, 10, 0, 199, 0.1, 213)));
+	}
+
+	/** Rounded random 100 x 1000 input generated with sparsity 0.01. */
+	@Test
+	public void randomVerySparse() {
+		convertAndVerify(TestUtils.ceil(TestUtils.generateTestMatrixBlock(100, 1000, 0, 199, 0.01, 213)));
+	}
+
+	/**
+	 * Convert {@code input} with {@code FrameFromMatrixBlock.convertToFrameBlock(input, 1)} and check
+	 * that the resulting frame holds the same values as the matrix.
+	 *
+	 * @param input the matrix to convert
+	 */
+	private void convertAndVerify(MatrixBlock input) {
+		final FrameBlock converted = FrameFromMatrixBlock.convertToFrameBlock(input, 1);
+		verifyEquivalence(input, converted);
+	}
+
@Test
public void timeChange() {
// MatrixBlock mb = TestUtils.generateTestMatrixBlock(64000,
2000, 1, 1, 0.5, 2340);
// for(int i = 0; i < 10; i++) {
- // Timing time = new Timing(true);
- // FrameFromMatrixBlock.convertToFrameBlock(mb,
ValueType.BOOLEAN, 1);
- // LOG.error(time.stop());
+ // Timing time = new Timing(true);
+ // FrameFromMatrixBlock.convertToFrameBlock(mb,
ValueType.BOOLEAN, 1);
+ // LOG.error(time.stop());
// }
// for(int i = 0; i < 10; i++) {
- // Timing time = new Timing(true);
- // FrameFromMatrixBlock.convertToFrameBlock(mb,
ValueType.BOOLEAN, 16);
- // LOG.error(time.stop());
+ // Timing time = new Timing(true);
+ // FrameFromMatrixBlock.convertToFrameBlock(mb,
ValueType.BOOLEAN, 16);
+ // LOG.error(time.stop());
// }
// for(int i = 0; i < 10; i ++){
@@ -176,6 +204,17 @@ public class FrameFromMatrixBlockTest {
}
+ private void verifyEquivalence(MatrixBlock mb, FrameBlock fb) {
+ int nRow = mb.getNumRows();
+ int nCol = mb.getNumColumns();
+ assertEquals(mb.getNumColumns(), fb.getSchema().length);
+
+ for(int i = 0; i < nRow; i++)
+ for(int j = 0; j < nCol; j++)
+ assertEquals(i + " " + j, mb.getValue(i, j),
fb.getDouble(i, j), 0.0000001);
+
+ }
+
private MatrixBlock mock(MatrixBlock m) {
MatrixBlock ret = new MatrixBlock(m.getNumRows(),
m.getNumColumns(),
new DenseBlockFP64Mock(new int[] {m.getNumRows(),
m.getNumColumns()}, m.getDenseBlockValues()));
diff --git a/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java b/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
new file mode 100644
index 0000000000..d35b47a4c6
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrix;
+
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+/**
+ * Parameterized comparison of matrix multiplication results: each case multiplies a generated left and
+ * right operand with a given parallelization degree and compares against a single-threaded reference
+ * result, with the operands in their generated, forced-dense, and forced-sparse formats.
+ */
+@RunWith(value = Parameterized.class)
+public class MatrixMultiplyTest {
+	protected static final Log LOG = LogFactory.getLog(MatrixMultiplyTest.class.getName());
+
+	/** Left-hand operand. */
+	private final MatrixBlock left;
+	/** Right-hand operand. */
+	private final MatrixBlock right;
+	/** Reference result, computed with parallelization degree 1 from the freshly generated operands. */
+	private final MatrixBlock exp;
+	/** Parallelization degree used for the multiplication under test. */
+	private final int k;
+
+	/**
+	 * Build one test case multiplying a {@code rows x common} matrix with a {@code common x cols} matrix.
+	 *
+	 * @param rows          number of rows of the left operand
+	 * @param common        columns of the left operand and rows of the right operand
+	 * @param cols          number of columns of the right operand
+	 * @param sparsityLeft  sparsity used to generate the left operand (ignored if it is 1x1)
+	 * @param sparsityRight sparsity used to generate the right operand (ignored if it is 1x1)
+	 * @param par           parallelization degree for the multiplication under test
+	 */
+	public MatrixMultiplyTest(int rows, int common, int cols, double sparsityLeft, double sparsityRight, int par) {
+		try {
+			// A 1x1 operand is always generated with sparsity 1 (presumably so a single cell is not empty).
+			// The right operand is 1x1 only when BOTH of its dimensions, common and cols, are 1.
+			final double sl = rows == 1 && common == 1 ? 1 : sparsityLeft;
+			final double sr = common == 1 && cols == 1 ? 1 : sparsityRight;
+			this.left = TestUtils.ceil(TestUtils.generateTestMatrixBlock(rows, common, -10, 10, sl, 13));
+			this.right = TestUtils.ceil(TestUtils.generateTestMatrixBlock(common, cols, -10, 10, sr, 14));
+
+			this.exp = multiply(left, right, 1);
+			this.k = par;
+		}
+		catch(Exception e) {
+			// the cause (and its stack trace) is preserved in the wrapper
+			throw new RuntimeException(e);
+		}
+	}
+
+	/**
+	 * Enumerate all combinations of left/right sparsity, parallelization degree and the three operand
+	 * dimensions. The loop nesting order fixes the index of each parameterization.
+	 *
+	 * @return one {@code {rows, common, cols, sparsityLeft, sparsityRight, par}} array per test case
+	 */
+	@Parameters
+	public static Collection<Object[]> data() {
+		List<Object[]> tests = new ArrayList<>();
+		try {
+			double[] sparsities = new double[] {0.001, 0.1, 0.5};
+			int[] is = new int[] {1, 3, 1024};
+			int[] js = new int[] {1, 3, 1024};
+			int[] ks = new int[] {1, 3, 1024};
+			int[] par = new int[] {1, 4};
+
+			for(double s : sparsities)
+				for(double s2 : sparsities)
+					for(int p : par)
+						for(int i : is)
+							for(int j : js)
+								for(int k : ks)
+									tests.add(new Object[] {i, j, k, s, s2, p});
+		}
+		catch(Exception e) {
+			e.printStackTrace();
+			fail("failed constructing tests");
+		}
+		return tests;
+	}
+
+	@Test
+	public void testMultiplicationAsIs() {
+		test(left, right);
+	}
+
+	// The tests below convert the operands in place. This relies on JUnit 4 creating a new test
+	// instance (and therefore newly generated operands) for each test method.
+
+	@Test
+	public void testLeftForceDense() {
+		left.sparseToDense();
+		test(left, right);
+	}
+
+	@Test
+	public void testRightForceDense() {
+		right.sparseToDense();
+		test(left, right);
+	}
+
+	@Test
+	public void testBothForceDense() {
+		left.sparseToDense();
+		right.sparseToDense();
+		test(left, right);
+	}
+
+	@Test
+	public void testLeftForceSparse() {
+		left.denseToSparse(true);
+		test(left, right);
+	}
+
+	@Test
+	public void testRightForceSparse() {
+		right.denseToSparse(true);
+		test(left, right);
+	}
+
+	@Test
+	public void testBothForceSparse() {
+		left.denseToSparse(true);
+		right.denseToSparse(true);
+		test(left, right);
+	}
+
+	/**
+	 * Multiply {@code a} and {@code b} with parallelization degree {@link #k} and compare to {@link #exp}.
+	 *
+	 * @param a left operand
+	 * @param b right operand
+	 */
+	private void test(MatrixBlock a, MatrixBlock b) {
+		try {
+			MatrixBlock ret = multiply(a, b, k);
+
+			String sparseErrMessage = "SparseLeft:" + a.isInSparseFormat() + " SparseRight: " + b.isInSparseFormat()
+				+ " SparseOut:" + exp.isInSparseFormat() + " SparseRet:" + ret.isInSparseFormat();
+			String sizeErrMessage = size(a) + " " + size(b) + " " + size(exp);
+			String totalMessage = "\n\n" + sizeErrMessage + "\n" + sparseErrMessage;
+
+			// Only include the full matrices in the message when they are small; long arithmetic avoids
+			// int overflow for large dimensions.
+			if((long) ret.getNumRows() * ret.getNumColumns() < 1000) {
+				totalMessage += "\n\nExp" + exp;
+				totalMessage += "\n\nAct" + ret;
+			}
+
+			TestUtils.compareMatricesPercentageDistance(exp, ret, 0.999, 0.99999, totalMessage, false);
+		}
+		catch(Exception e) {
+			// report as a test failure but keep the original exception as the cause
+			throw new AssertionError("Multiplication failed: " + e.getMessage(), e);
+		}
+	}
+
+	/** Describe a block as {@code <rows>x<cols> nnz:<nonZeros>} for failure messages. */
+	private static String size(MatrixBlock a) {
+		return a.getNumRows() + "x" + a.getNumColumns() + " nnz:" + a.getNonZeros();
+	}
+
+	/**
+	 * Compute {@code a %*% b} as an aggregate-binary operation (multiply, aggregated with plus).
+	 *
+	 * @param a left operand
+	 * @param b right operand
+	 * @param k parallelization degree passed to the operator
+	 * @return the product
+	 */
+	private static MatrixBlock multiply(MatrixBlock a, MatrixBlock b, int k) {
+		AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
+		AggregateBinaryOperator mult = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg, k);
+		return a.aggregateBinaryOperations(a, b, mult);
+	}
+}