This is an automated email from the ASF dual-hosted git repository.
ywcb00 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c50c769153 [SYSTEMDS-3948] Implement Row-wise Sparsity Estimator
c50c769153 is described below
commit c50c769153de7a7ed4aa36a9c9fa86d3bd3c569e
Author: ywcb00 <[email protected]>
AuthorDate: Thu May 21 09:28:37 2026 +0200
[SYSTEMDS-3948] Implement Row-wise Sparsity Estimator
This commit implements the row-wise sparsity estimator and adds respective
test cases.
Closes #2466.
---
.../apache/sysds/hops/estim/EstimatorRowWise.java | 356 +++++++++++++++++++++
.../test/component/estim/OpBindChainTest.java | 116 ++++---
.../sysds/test/component/estim/OpBindTest.java | 105 +++---
.../test/component/estim/OpElemWChainTest.java | 98 +++---
.../sysds/test/component/estim/OpElemWTest.java | 96 +++---
.../sysds/test/component/estim/OpSingleTest.java | 166 ++++++----
.../test/component/estim/OuterProductTest.java | 107 +++----
.../test/component/estim/SelfProductTest.java | 146 ++++-----
.../component/estim/SquaredProductChainTest.java | 132 ++++----
.../test/component/estim/SquaredProductTest.java | 138 ++++----
10 files changed, 917 insertions(+), 543 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/estim/EstimatorRowWise.java
b/src/main/java/org/apache/sysds/hops/estim/EstimatorRowWise.java
new file mode 100644
index 0000000000..0cee03c8df
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/estim/EstimatorRowWise.java
@@ -0,0 +1,356 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.estim;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.data.SparseRow;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+
+import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
+
+/**
+ * This estimator implements an approach based on row-wise sparsity estimation,
+ * introduced in
+ * Lin, Chunxu, Wensheng Luo, Yixiang Fang, Chenhao Ma, Xilin Liu and Yuchi Ma:
+ * On Efficient Large Sparse Matrix Chain Multiplication.
+ * Proceedings of the ACM on Management of Data 2 (2024): 1 - 27.
+ */
+public class EstimatorRowWise extends SparsityEstimator {
+ /**
+  * Estimates the output characteristics of the operator DAG rooted at {@code root}.
+  * The row-wise sparsity vector stored as synopsis of the root is averaged into a
+  * single sparsity value, from which the output characteristics are derived.
+  *
+  * @param root root node of the DAG to estimate
+  * @return the value returned by {@code root.setDataCharacteristics}
+  */
+ @Override
+ public DataCharacteristics estim(MMNode root) {
+     // the chain estimation stores the row sparsity vector as synopsis of each visited node
+     estimInternChain(root);
+     final double[] rsRoot = (double[]) root.getSynopsis();
+     final double avgSparsity = DoubleStream.of(rsRoot).average().orElse(0);
+
+     return root.setDataCharacteristics(deriveOutputCharacteristics(root, avgSparsity));
+ }
+
+ /**
+  * Estimates the output sparsity of the matrix multiplication {@code m1 %*% m2};
+  * shorthand for {@link #estim(MatrixBlock, MatrixBlock, OpCode)} with {@code OpCode.MM}.
+  *
+  * @param m1 left input matrix
+  * @param m2 right input matrix
+  * @return estimated sparsity of the product
+  */
+ @Override
+ public double estim(MatrixBlock m1, MatrixBlock m2) {
+     final OpCode matMult = OpCode.MM;
+     return estim(m1, m2, matMult);
+ }
+
+ /**
+  * Estimates the output sparsity of the binary operation {@code op} applied to
+  * {@code m1} and {@code m2}. Operations that qualify as exact-metadata operations
+  * are answered from the data characteristics of the inputs; all others use the
+  * mean of the estimated row-wise sparsity vector (0 if that vector is empty).
+  *
+  * @param m1 left input matrix
+  * @param m2 right input matrix
+  * @param op binary operation to estimate
+  * @return estimated output sparsity
+  */
+ @Override
+ public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
+     if( !isExactMetadataOp(op, m1.getNumColumns()) ) {
+         final double[] rowSparsities = estimIntern(m1, m2, op);
+         return DoubleStream.of(rowSparsities).average().orElse(0);
+     }
+
+     final DataCharacteristics exact = estimExactMetaData(
+         m1.getDataCharacteristics(), m2.getDataCharacteristics(), op);
+     return exact.getSparsity();
+ }
+
+ /**
+  * Estimates the output sparsity of the unary operation {@code op} applied to {@code m1}.
+  * Exact-metadata operations are answered from the data characteristics of the input;
+  * all others use the mean of the estimated row-wise sparsity vector (0 if empty).
+  *
+  * @param m1 input matrix
+  * @param op unary operation to estimate
+  * @return estimated output sparsity
+  */
+ @Override
+ public double estim(MatrixBlock m1, OpCode op) {
+     if( !isExactMetadataOp(op, m1.getNumColumns()) ) {
+         final double[] rowSparsities = estimIntern(m1, op);
+         return DoubleStream.of(rowSparsities).average().orElse(0);
+     }
+
+     final DataCharacteristics exact = estimExactMetaData(m1.getDataCharacteristics(), null, op);
+     return exact.getSparsity();
+ }
+
+ /**
+  * Estimates the row-wise sparsity vector of the DAG rooted at {@code node},
+  * treating it as a stand-alone DAG, i.e., without any right neighbor.
+  *
+  * @param node root of the (sub-)DAG
+  * @return row-wise sparsity vector of the node's output
+  */
+ private double[] estimInternChain(MMNode node) {
+     final double[] noRightNeighbor = null;
+     final OpCode noRightNeighborOp = null;
+     return estimInternChain(node, noRightNeighbor, noRightNeighborOp);
+ }
+
+ /**
+  * Recursively estimates the row-wise sparsity vector of the DAG rooted at {@code node}.
+  * The result is stored as synopsis of the node, and the node's data characteristics
+  * are set from the mean of that vector.
+  *
+  * For MM nodes, the right child is estimated first and its row sparsity vector is
+  * handed to the left child as "right neighbor". CBIND, RBIND, PLUS, and MULT nodes
+  * are estimated as separate DAGs (cut) and merged with a right neighbor, if any,
+  * via the fallback MM estimate.
+  *
+  * @param node            root of the (sub-)DAG
+  * @param rsRightNeighbor row sparsity vector of the right neighbor, or null if there is none
+  * @param opRightNeighbor operation connecting this node with the right neighbor (may be null
+  *                        if {@code rsRightNeighbor} is null)
+  * @return row-wise sparsity vector of the node's output, combined with the right neighbor if given
+  * @throws NotImplementedException for unsupported operators, or if a cut operator has a
+  *                                 right neighbor connected by an operation other than MM
+  */
+ private double[] estimInternChain(MMNode node, double[] rsRightNeighbor, OpCode opRightNeighbor) {
+     double[] rsOut;
+     if(node.isLeaf()) {
+         MatrixBlock mb = node.getData();
+         if(rsRightNeighbor != null)
+             rsOut = estimIntern(mb, rsRightNeighbor, opRightNeighbor);
+         else
+             rsOut = getRowWiseSparsityVector(mb);
+     }
+     else {
+         MMNode nodeLeft = node.getLeft();
+         MMNode nodeRight = node.getRight();
+         switch(node.getOp()) {
+             case MM:
+                 double[] rsRightMM = estimInternChain(nodeRight, rsRightNeighbor, opRightNeighbor);
+                 rsOut = estimInternChain(nodeLeft, rsRightMM, node.getOp());
+                 break;
+             case CBIND:
+             case RBIND:
+             case PLUS:
+             case MULT:
+                 // NOTE: considering the current node as new DAG for estimation (cut), since the
+                 // row sparsity of the right neighbor cannot be aggregated into a bind or
+                 // element-wise operation when having only row sparsity vectors
+                 double[] rsLeft = estimInternChain(nodeLeft);
+                 double[] rsRight = estimInternChain(nodeRight);
+                 double[] rsCut = estimInternCutOp(node.getOp(), rsLeft, rsRight);
+                 rsOut = aggregateRightNeighbor(rsCut, rsRightNeighbor, opRightNeighbor);
+                 break;
+             default:
+                 throw new NotImplementedException("Chain estimation for operator " + node.getOp().toString() +
+                     " is not supported yet.");
+         }
+     }
+     node.setSynopsis(rsOut);
+     node.setDataCharacteristics(deriveOutputCharacteristics(node,
+         DoubleStream.of(rsOut).average().orElse(0)));
+     return rsOut;
+ }
+
+ /**
+  * Dispatches a bind or element-wise operation on two row sparsity vectors
+  * to the respective estimation routine.
+  */
+ private double[] estimInternCutOp(OpCode op, double[] rsLeft, double[] rsRight) {
+     switch(op) {
+         case CBIND:
+             return estimInternCBind(rsLeft, rsRight);
+         case RBIND:
+             return estimInternRBind(rsLeft, rsRight);
+         case PLUS:
+             return estimInternPlus(rsLeft, rsRight);
+         case MULT:
+             return estimInternMult(rsLeft, rsRight);
+         default:
+             throw new NotImplementedException("Chain estimation for operator " + op.toString() +
+                 " is not supported yet.");
+     }
+ }
+
+ /**
+  * Combines the row sparsity vector of a cut sub-DAG with its right neighbor, if any.
+  * The supported-operation check runs before the fallback estimate is computed,
+  * so no work is spent on a result that would be discarded by the exception.
+  */
+ private double[] aggregateRightNeighbor(double[] rsCut, double[] rsRightNeighbor, OpCode opRightNeighbor) {
+     if(rsRightNeighbor == null)
+         return rsCut;
+     if(opRightNeighbor != OpCode.MM)
+         throw new NotImplementedException("Fallback sparsity estimation has only been " +
+             "considered for MM operation w/ right neighbor yet.");
+     return estimInternMMFallback(rsCut, rsRightNeighbor);
+ }
+
+ /**
+  * Estimates the row-wise output sparsity of {@code op} applied to two matrices by
+  * reducing the right input to its row sparsity vector first.
+  *
+  * @param m1 left input matrix
+  * @param m2 right input matrix
+  * @param op binary operation to estimate
+  * @return row-wise sparsity vector of the output
+  */
+ private double[] estimIntern(MatrixBlock m1, MatrixBlock m2, OpCode op) {
+     return estimIntern(m1, getRowWiseSparsityVector(m2), op);
+ }
+
+ /**
+  * Estimates the row-wise output sparsity of {@code op} applied to the matrix {@code m1}
+  * and a right input that is only known by its row sparsity vector.
+  *
+  * @param m1   left input matrix
+  * @param rsM2 row sparsity vector of the right input
+  * @param op   binary operation to estimate (MM, CBIND, RBIND, PLUS, or MULT)
+  * @return row-wise sparsity vector of the output
+  * @throws NotImplementedException for any other operation
+  */
+ private double[] estimIntern(MatrixBlock m1, double[] rsM2, OpCode op) {
+     final double[] rsResult;
+     switch(op) {
+         case MM:
+             // only MM exploits the non-zero structure of m1 itself
+             rsResult = estimInternMM(m1, rsM2);
+             break;
+         case CBIND:
+             rsResult = estimInternCBind(getRowWiseSparsityVector(m1), rsM2);
+             break;
+         case RBIND:
+             rsResult = estimInternRBind(getRowWiseSparsityVector(m1), rsM2);
+             break;
+         case PLUS:
+             rsResult = estimInternPlus(getRowWiseSparsityVector(m1), rsM2);
+             break;
+         case MULT:
+             rsResult = estimInternMult(getRowWiseSparsityVector(m1), rsM2);
+             break;
+         default:
+             throw new NotImplementedException("Sparsity estimation for operation " + op.toString() + " not supported yet.");
+     }
+     return rsResult;
+ }
+
+ /**
+  * Estimates the row-wise output sparsity of a unary operation; only DIAG is supported.
+  *
+  * @param mb input matrix
+  * @param op unary operation to estimate
+  * @return row-wise sparsity vector of the output
+  * @throws NotImplementedException for any operation other than DIAG
+  */
+ private double[] estimIntern(MatrixBlock mb, OpCode op) {
+     if(op == OpCode.DIAG)
+         return estimInternDiag(mb);
+     throw new NotImplementedException("Sparsity estimation for operation " + op.toString() + " not supported yet.");
+ }
+
+ /**
+  * Corresponds to Algorithm 1 in the publication.
+  *
+  * For every row of {@code m1}, the output row sparsity is one minus the product of
+  * {@code (1 - rsM2[c])} over all columns c holding a non-zero in that row.
+  *
+  * @param m1   left input matrix
+  * @param rsM2 row sparsity vector of the right input, indexed by the column indices of m1
+  * @return row-wise sparsity vector of the product, one entry per row of m1
+  */
+ private double[] estimInternMM(MatrixBlock m1, double[] rsM2) {
+     final int numRows = m1.getNumRows();
+     final double[] rsOut = new double[numRows];
+     for(int r = 0; r < numRows; r++) {
+         final int[] nzCols = getNonZeroColumnIndices(m1, r);
+         // running product of the complements of the touched right-hand row sparsities
+         double complement = 1;
+         for(int pos = 0; pos < nzCols.length; pos++)
+             complement *= 1.0 - rsM2[nzCols[pos]];
+         rsOut[r] = 1 - complement;
+     }
+     return rsOut;
+ }
+
+ /**
+  * NOTE: fallback estimate using the uniform estimator (aka average-case estimator,
+  * Naive Bayes estimator) for the case when we are limited to the row sparsity
+  * vectors of both inputs.
+  * NOTE: Considering the average of the second matrix would probably not be far off
+  * while saving computing time.
+  *
+  * @param rsM1 row sparsity vector of the left input
+  * @param rsM2 row sparsity vector of the right input
+  * @return row-wise sparsity vector of the product, one entry per entry of rsM1
+  */
+ private double[] estimInternMMFallback(double[] rsM1, double[] rsM2) {
+     final double[] rsOut = new double[rsM1.length];
+     for(int row = 0; row < rsM1.length; row++) {
+         final double rsLeft = rsM1[row];
+         // an empty left row yields an empty output row (entry stays at its initial 0)
+         if(rsLeft == 0)
+             continue;
+
+         double complement = 1;
+         for(double rsRight : rsM2)
+             complement *= 1.0 - (rsLeft * rsRight);
+         rsOut[row] = 1.0 - complement;
+     }
+     return rsOut;
+ }
+
+ /**
+  * Estimates the row sparsity vector of {@code cbind(M1, M2)} as the element-wise
+  * mean of both input row sparsity vectors.
+  *
+  * @param rsM1 row sparsity vector of the left input
+  * @param rsM2 row sparsity vector of the right input (needs at least rsM1.length entries)
+  * @return row-wise sparsity vector of the output, one entry per entry of rsM1
+  */
+ private double[] estimInternCBind(double[] rsM1, double[] rsM2) {
+     // FIXME: this estimate assumes that the number of columns is equivalent for both inputs
+     return IntStream.range(0, rsM1.length)
+         .mapToDouble(row -> (rsM1[row] + rsM2[row]) / 2.0)
+         .toArray();
+ }
+
+ /**
+  * Estimates the row sparsity vector of {@code rbind(M1, M2)}: the rows of both inputs
+  * are retained, so the result is the concatenation of both vectors.
+  *
+  * @param rsM1 row sparsity vector of the upper input
+  * @param rsM2 row sparsity vector of the lower input
+  * @return concatenation of rsM1 and rsM2
+  */
+ private double[] estimInternRBind(double[] rsM1, double[] rsM2) {
+     final double[] rsStacked = ArrayUtils.addAll(rsM1, rsM2);
+     return rsStacked;
+ }
+
+ /**
+  * Row-wise average-case estimate for element-wise addition:
+  * {@code rsM1 + rsM2 - (rsM1 * rsM2)} per row.
+  *
+  * @param rsM1 row sparsity vector of the left input
+  * @param rsM2 row sparsity vector of the right input (needs at least rsM1.length entries)
+  * @return row-wise sparsity vector of the output, one entry per entry of rsM1
+  */
+ private double[] estimInternPlus(double[] rsM1, double[] rsM2) {
+     final double[] rsSum = new double[rsM1.length];
+     int row = 0;
+     while(row < rsSum.length) {
+         final double left = rsM1[row];
+         final double right = rsM2[row];
+         rsSum[row] = left + right - (left * right);
+         row++;
+     }
+     return rsSum;
+ }
+
+ /**
+  * Row-wise average-case estimate for element-wise multiplication:
+  * {@code rsM1 * rsM2} per row.
+  *
+  * @param rsM1 row sparsity vector of the left input
+  * @param rsM2 row sparsity vector of the right input (needs at least rsM1.length entries)
+  * @return row-wise sparsity vector of the output, one entry per entry of rsM1
+  */
+ private double[] estimInternMult(double[] rsM1, double[] rsM2) {
+     return IntStream.range(0, rsM1.length)
+         .mapToDouble(row -> rsM1[row] * rsM2[row])
+         .toArray();
+ }
+
+ /**
+  * Row sparsity vector of the diagonal of {@code mb}: an entry is 1 if the diagonal
+  * cell of that row is non-zero, and 0 otherwise.
+  *
+  * @param mb input matrix; cell (r, r) is read for every row r
+  * @return vector with one 0/1 entry per row of mb
+  */
+ private double[] estimInternDiag(MatrixBlock mb) {
+     final int numRows = mb.getNumRows();
+     final double[] rsDiag = new double[numRows];
+     for(int r = 0; r < numRows; r++) {
+         // entries are initialized with 0, only non-zero diagonal cells are marked
+         if(mb.get(r, r) != 0)
+             rsDiag[r] = 1;
+     }
+     return rsDiag;
+ }
+
+ /**
+  * Computes the row sparsity vector of a matrix, i.e., for every row the number of
+  * non-zeros divided by the number of columns.
+  *
+  * A matrix whose sparse or dense block is not allocated is treated as empty
+  * and yields an all-zero vector instead of failing with a NullPointerException.
+  *
+  * @param mb input matrix
+  * @return vector with one sparsity value per row of mb
+  */
+ private double[] getRowWiseSparsityVector(MatrixBlock mb) {
+     final int numRows = mb.getNumRows();
+     final int numCols = mb.getNumColumns();
+     double[] rsOut = new double[numRows];
+     if(mb.isInSparseFormat()) {
+         // NOTE(review): presumably an unallocated block means "no non-zeros" - confirm
+         if(mb.getSparseBlock() == null)
+             return rsOut;
+         for(int rIdx = 0; rIdx < numRows; rIdx++) {
+             SparseRow sparseRow = mb.getSparseBlock().get(rIdx);
+             rsOut[rIdx] = (sparseRow == null) ? 0 : (double) sparseRow.size() / numCols;
+         }
+     }
+     else {
+         if(mb.getDenseBlock() == null)
+             return rsOut;
+         for(int rIdx = 0; rIdx < numRows; rIdx++) {
+             rsOut[rIdx] = (double) mb.getDenseBlock().countNonZeros(rIdx) / numCols;
+         }
+     }
+     return rsOut;
+ }
+
+ /**
+  * Returns the column indices of the non-zero cells in row {@code rIdx} of {@code mb}.
+  *
+  * @param mb   input matrix
+  * @param rIdx row index
+  * @return column indices of the non-zeros in the given row (empty if there are none)
+  */
+ private int[] getNonZeroColumnIndices(MatrixBlock mb, final int rIdx) {
+     int[] nonZeroCols;
+     if(mb.isInSparseFormat()) {
+         // NOTE(review): presumably an unallocated block means "no non-zeros" - confirm
+         if(mb.getSparseBlock() == null)
+             return new int[0];
+         SparseRow sparseRow = mb.getSparseBlock().get(rIdx);
+         // indexes() hands out the backing array of the row, which may hold more slots than
+         // size() valid entries; only the first size() slots are column indices of non-zeros
+         nonZeroCols = (sparseRow == null) ? new int[0] :
+             ArrayUtils.subarray(sparseRow.indexes(), 0, sparseRow.size());
+     }
+     else {
+         nonZeroCols = IntStream.range(0, mb.getNumColumns())
+             .filter(cIdx -> mb.get(rIdx, cIdx) != 0).toArray();
+     }
+     return nonZeroCols;
+ }
+
+ /**
+  * Derives the output data characteristics (dimensions and number of non-zeros) of a node
+  * from the dimensions of its inputs and the estimated output sparsity.
+  *
+  * Leaf nodes and nodes that already carry data characteristics with a known number of
+  * non-zeros return their existing data characteristics unchanged.
+  *
+  * The right input is only accessed for binary operations (MM, RBIND, CBIND), so that
+  * unary operations work on nodes without a right input.
+  *
+  * @param node  node whose output characteristics should be derived
+  * @param spOut estimated sparsity of the node's output
+  * @return data characteristics of the node's output
+  * @throws NotImplementedException for RESHAPE and any operation not listed here
+  */
+ public static DataCharacteristics deriveOutputCharacteristics(MMNode node, double spOut) {
+     if(node.isLeaf() ||
+         (node.getDataCharacteristics() != null && node.getDataCharacteristics().getNonZeros() != -1)) {
+         return node.getDataCharacteristics();
+     }
+
+     MMNode nodeLeft = node.getLeft();
+     // NOTE(review): presumably null for unary operations - confirm against MMNode
+     MMNode nodeRight = node.getRight();
+     int leftNRow = nodeLeft.getRows();
+     int leftNCol = nodeLeft.getCols();
+     switch(node.getOp()) {
+         case MM: {
+             int rightNCol = nodeRight.getCols();
+             return new MatrixCharacteristics(leftNRow, rightNCol,
+                 OptimizerUtils.getNnz(leftNRow, rightNCol, spOut));
+         }
+         case MULT:
+         case PLUS:
+         case NEQZERO:
+         case EQZERO:
+             return new MatrixCharacteristics(leftNRow, leftNCol,
+                 OptimizerUtils.getNnz(leftNRow, leftNCol, spOut));
+         case RBIND: {
+             int outNRow = leftNRow + nodeRight.getRows();
+             return new MatrixCharacteristics(outNRow, leftNCol,
+                 OptimizerUtils.getNnz(outNRow, leftNCol, spOut));
+         }
+         case CBIND: {
+             int outNCol = leftNCol + nodeRight.getCols();
+             return new MatrixCharacteristics(leftNRow, outNCol,
+                 OptimizerUtils.getNnz(leftNRow, outNCol, spOut));
+         }
+         case DIAG: {
+             // vector input -> square matrix output; matrix input -> column vector output
+             int ncol = (leftNCol == 1) ? leftNRow : 1;
+             return new MatrixCharacteristics(leftNRow, ncol,
+                 OptimizerUtils.getNnz(leftNRow, ncol, spOut));
+         }
+         case TRANS:
+             return new MatrixCharacteristics(leftNCol, leftNRow,
+                 OptimizerUtils.getNnz(leftNCol, leftNRow, spOut));
+         case RESHAPE:
+             throw new NotImplementedException("Characteristics derivation for " + node.getOp() +" has not been " +
+                 "implemented yet, but could be implemented similar to EstimatorMatrixHistogram.java");
+         default:
+             throw new NotImplementedException("Characteristics derivation for " + node.getOp() +
+                 " is not supported.");
+     }
+ }
+};
diff --git
a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
index 35efedaf62..9626f9eb74 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
@@ -19,11 +19,11 @@
package org.apache.sysds.test.component.estim;
-import org.junit.Test;
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
+import org.apache.sysds.hops.estim.EstimatorRowWise;
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
import org.apache.sysds.hops.estim.MMNode;
import org.apache.sysds.hops.estim.SparsityEstimator;
@@ -32,136 +32,146 @@ import
org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
import org.apache.commons.lang3.NotImplementedException;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
/**
- * this is the basic operation check for all estimators with single operations
+ * this is the basic operation check for all estimators with chains of
operations including binding operations
*/
-public class OpBindChainTest extends AutomatedTestBase
+@RunWith(value = Parameterized.class)
+public class OpBindChainTest extends AutomatedTestBase
{
- private final static int m = 600;
- private final static int k = 300;
- private final static int n = 100;
- private final static double[] sparsity = new double[]{0.2, 0.4};
-// private final static OpCode mult = OpCode.MULT;
-// private final static OpCode plus = OpCode.PLUS;
- private final static OpCode rbind = OpCode.RBIND;
- private final static OpCode cbind = OpCode.CBIND;
-// private final static OpCode eqzero = OpCode.EQZERO;
-// private final static OpCode diag = OpCode.DIAG;
-// private final static OpCode neqzero = OpCode.NEQZERO;
-// private final static OpCode trans = OpCode.TRANS;
-// private final static OpCode reshape = OpCode.RESHAPE;
+ @Parameterized.Parameter(0)
+ public int m;
+ @Parameterized.Parameter(1)
+ public int k;
+ @Parameterized.Parameter(2)
+ public int n;
+ @Parameterized.Parameter(3)
+ public double[] sparsity;
@Override
public void setUp() {
//do nothing
}
-
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ // {m, k, n, sparsity}
+ {600, 300, 100, new double[]{0.2, 0.4}},
+ {600, 200, 300, new double[]{0.1, 0.15}},
+ });
+ }
+
//Average Case
@Test
public void testAvgRbind() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n,
sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), OpCode.RBIND);
}
@Test
public void testAvgCbind() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n,
sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), OpCode.CBIND);
}
//Worst Case
@Test
public void testWorstRbind() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n,
sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorBasicWorst(),
OpCode.RBIND);
}
@Test
public void testWorstCbind() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n,
sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorBasicWorst(),
OpCode.CBIND);
}
//DensityMap
/*@Test
public void testDMCaserbind() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n,
sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorDensityMap(),
OpCode.RBIND);
}
@Test
public void testDMCasecbind() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n,
sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorDensityMap(),
OpCode.CBIND);
}*/
//MNC
@Test
public void testMNCRbind() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k,
n, sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(),
OpCode.RBIND);
}
@Test
public void testMNCCbind() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k,
n, sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(),
OpCode.CBIND);
}
//Bitset
@Test
public void testBitsetCaserbind() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n,
sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), OpCode.RBIND);
}
@Test
public void testBitsetCasecbind() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n,
sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), OpCode.CBIND);
}
//Layered Graph
@Test
public void testLGCaserbind() {
runSparsityEstimateTest(
- new EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS,
7),
- m, k, n, sparsity, rbind);
+ new EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS,
7), OpCode.RBIND);
}
@Test
public void testLGCasecbind() {
runSparsityEstimateTest(
- new EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS,
3),
- m, k, n, sparsity, cbind);
+ new EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS,
3), OpCode.CBIND);
}
-
-
- private static void runSparsityEstimateTest(SparsityEstimator estim,
int m, int k, int n, double[] sp, OpCode op) {
- MatrixBlock m1;
+
+ // Row Wise Sparsity Estimator
+ @Test
+ public void testRowWiseRbind() {
+ runSparsityEstimateTest(new EstimatorRowWise(), OpCode.RBIND);
+ }
+
+ @Test
+ public void testRowWiseCbind() {
+ runSparsityEstimateTest(new EstimatorRowWise(), OpCode.CBIND);
+ }
+
+
+ private void runSparsityEstimateTest(SparsityEstimator estim, OpCode
op) {
+ MatrixBlock m1 = MatrixBlock.randOperations(m, k, sparsity[0],
1, 1, "uniform", 3);
MatrixBlock m2;
MatrixBlock m3 = new MatrixBlock();
MatrixBlock m4;
- MatrixBlock m5 = new MatrixBlock();
- double est = 0;
switch(op) {
case RBIND:
- m1 = MatrixBlock.randOperations(m, k, sp[0], 1,
1, "uniform", 3);
- m2 = MatrixBlock.randOperations(n, k, sp[1], 1,
1, "uniform", 7);
+ m2 = MatrixBlock.randOperations(n, k,
sparsity[1], 1, 1, "uniform", 7);
m1.append(m2, m3, false);
- m4 = MatrixBlock.randOperations(k, m, sp[1], 1,
1, "uniform", 5);
- m5 = m3.aggregateBinaryOperations(m3, m4,
- new MatrixBlock(),
InstructionUtils.getMatMultOperator(1));
- est = estim.estim(new MMNode(new MMNode(new
MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity();
- //System.out.println(est);
- //System.out.println(m5.getSparsity());
+ m4 = MatrixBlock.randOperations(k, m,
sparsity[1], 1, 1, "uniform", 5);
break;
case CBIND:
- m1 = MatrixBlock.randOperations(m, k, sp[0], 1,
1, "uniform", 3);
- m2 = MatrixBlock.randOperations(m, n, sp[1], 1,
1, "uniform", 7);
+ m2 = MatrixBlock.randOperations(m, n,
sparsity[1], 1, 1, "uniform", 7);
m1.append(m2, m3, true);
- m4 = MatrixBlock.randOperations(k+n, m, sp[1],
1, 1, "uniform", 5);
- m5 = m3.aggregateBinaryOperations(m3, m4,
- new MatrixBlock(),
InstructionUtils.getMatMultOperator(1));
- est = estim.estim(new MMNode(new MMNode(new
MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity();
- //System.out.println(est);
- //System.out.println(m5.getSparsity());
+ m4 = MatrixBlock.randOperations(k+n, m,
sparsity[1], 1, 1, "uniform", 5);
break;
default:
throw new NotImplementedException();
}
+ MatrixBlock m5 = m3.aggregateBinaryOperations(m3, m4,
+ new MatrixBlock(),
InstructionUtils.getMatMultOperator(1));
+ double est = estim.estim(new MMNode(new MMNode(new MMNode(m1),
new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity();
//compare estimated and real sparsity
TestUtils.compareScalars(est, m5.getSparsity(),
(estim instanceof EstimatorBasicWorst) ? 5e-1 :
diff --git
a/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java
b/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java
index 3e7ad24fe8..c943a06be1 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpBindTest.java
@@ -19,146 +19,163 @@
package org.apache.sysds.test.component.estim;
-import org.junit.Test;
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
+import org.apache.sysds.hops.estim.EstimatorRowWise;
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.hops.estim.SparsityEstimator.OpCode;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
import org.apache.commons.lang3.NotImplementedException;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
/**
- * this is the basic operation check for all estimators with single operations
+ * this is the basic operation check for all estimators with binding operations
*/
-public class OpBindTest extends AutomatedTestBase
+@RunWith(value = Parameterized.class)
+public class OpBindTest extends AutomatedTestBase
{
- private final static int m = 600;
- private final static int k = 300;
- private final static int n = 100;
- private final static double[] sparsity = new double[]{0.2, 0.4};
-// private final static OpCode mult = OpCode.MULT;
-// private final static OpCode plus = OpCode.PLUS;
- private final static OpCode rbind = OpCode.RBIND;
- private final static OpCode cbind = OpCode.CBIND;
-// private final static OpCode eqzero = OpCode.EQZERO;
-// private final static OpCode diag = OpCode.DIAG;
-// private final static OpCode neqzero = OpCode.NEQZERO;
-// private final static OpCode trans = OpCode.TRANS;
-// private final static OpCode reshape = OpCode.RESHAPE;
+ @Parameterized.Parameter(0)
+ public int m;
+ @Parameterized.Parameter(1)
+ public int k;
+ @Parameterized.Parameter(2)
+ public int n;
+ @Parameterized.Parameter(3)
+ public double[] sparsity;
@Override
public void setUp() {
//do nothing
}
-
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ // {m, k, n, sparsity}
+ {600, 300, 100, new double[]{0.2, 0.4}},
+ {600, 200, 300, new double[]{0.1, 0.15}},
+ });
+ }
+
//Average Case
@Test
public void testAvgRbind() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n,
sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), OpCode.RBIND);
}
@Test
public void testAvgCbind() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n,
sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), OpCode.CBIND);
}
//Worst Case
@Test
public void testWorstRbind() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n,
sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorBasicWorst(),
OpCode.RBIND);
}
@Test
public void testWorstCbind() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n,
sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorBasicWorst(),
OpCode.CBIND);
}
//DensityMap
/*@Test
public void testDMCaserbind() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n,
sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorDensityMap(),
OpCode.RBIND);
}
@Test
public void testDMCasecbind() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n,
sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorDensityMap(),
OpCode.CBIND);
}*/
//MNC
@Test
public void testMNCRbind() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k,
n, sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(),
OpCode.RBIND);
}
@Test
public void testMNCCbind() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k,
n, sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(),
OpCode.CBIND);
}
//Bitset
@Test
public void testBitsetCasecbind() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n,
sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), OpCode.CBIND);
}
@Test
public void testBitsetCaserbind() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n,
sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), OpCode.RBIND);
}
//Layered Graph
@Test
public void testLGCaserbind() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n,
sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorLayeredGraph(),
OpCode.RBIND);
}
@Test
public void testLGCasecbind() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n,
sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorLayeredGraph(),
OpCode.CBIND);
}
//Sample
/*@Test
public void testSampleCaserbind() {
- runSparsityEstimateTest(new EstimatorSample(), m, k, n,
sparsity, rbind);
+ runSparsityEstimateTest(new EstimatorSample(), OpCode.RBIND);
}
@Test
public void testSampleCasecbind() {
- runSparsityEstimateTest(new EstimatorSample(), m, k, n,
sparsity, cbind);
+ runSparsityEstimateTest(new EstimatorSample(), OpCode.CBIND);
}*/
+
+ // Row Wise Sparsity Estimator
+ @Test
+ public void testRowWiseRbind() {
+ runSparsityEstimateTest(new EstimatorRowWise(), OpCode.RBIND);
+ }
+
+ @Test
+ public void testRowWiseCbind() {
+ runSparsityEstimateTest(new EstimatorRowWise(), OpCode.CBIND);
+ }
+
-
- private static void runSparsityEstimateTest(SparsityEstimator estim,
int m, int k, int n, double[] sp, OpCode op) {
+ private void runSparsityEstimateTest(SparsityEstimator estim, OpCode
op) {
MatrixBlock m1;
MatrixBlock m2;
MatrixBlock m3 = new MatrixBlock();
- double est = 0;
switch(op) {
case RBIND:
- m1 = MatrixBlock.randOperations(m, k, sp[0], 1,
1, "uniform", 3);
- m2 = MatrixBlock.randOperations(n, k, sp[1], 1,
1, "uniform", 3);
+ m1 = MatrixBlock.randOperations(m, k,
sparsity[0], 1, 1, "uniform", 3);
+ m2 = MatrixBlock.randOperations(n, k,
sparsity[1], 1, 1, "uniform", 3);
m1.append(m2, m3, false);
- est = estim.estim(m1, m2, op);
- // System.out.println(est);
- // System.out.println(m3.getSparsity());
break;
case CBIND:
- m1 = MatrixBlock.randOperations(10, 130, sp[0],
1, 1, "uniform", 3);
- m2 = MatrixBlock.randOperations(10, 70, sp[1],
1, 1, "uniform", 3);
+ m1 = MatrixBlock.randOperations(10, 130,
sparsity[0], 1, 1, "uniform", 3);
+ m2 = MatrixBlock.randOperations(10, 70,
sparsity[1], 1, 1, "uniform", 3);
m1.append(m2, m3);
- est = estim.estim(m1, m2, op);
- // System.out.println(est);
- // System.out.println(m3.getSparsity());
break;
default:
throw new NotImplementedException();
}
+ double est = estim.estim(m1, m2, op);
//compare estimated and real sparsity
TestUtils.compareScalars(est, m3.getSparsity(),
(estim instanceof EstimatorBasicWorst) ? 5e-1 :
diff --git
a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
index a1b6594a92..da18067867 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
@@ -19,12 +19,12 @@
package org.apache.sysds.test.component.estim;
-import org.junit.Test;
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
+import org.apache.sysds.hops.estim.EstimatorRowWise;
import org.apache.sysds.hops.estim.EstimatorDensityMap;
import org.apache.sysds.hops.estim.MMNode;
import org.apache.sysds.hops.estim.SparsityEstimator;
@@ -36,120 +36,140 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
import org.apache.commons.lang3.NotImplementedException;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
/**
- * this is the basic operation check for all estimators with single operations
+ * this is the basic operation check for all estimators with chains of
operations including element-wise operations
*/
+@RunWith(value = Parameterized.class)
public class OpElemWChainTest extends AutomatedTestBase
{
- private final static int m = 1600;
- private final static int n = 700;
- private final static double[] sparsity = new double[]{0.1, 0.04};
- private final static OpCode mult = OpCode.MULT;
- private final static OpCode plus = OpCode.PLUS;
+ @Parameterized.Parameter(0)
+ public int m;
+ @Parameterized.Parameter(1)
+ public int n;
+ @Parameterized.Parameter(2)
+ public double[] sparsity;
@Override
public void setUp() {
//do nothing
}
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ // {m, n, sparsity}
+ {1600, 700, new double[]{0.1, 0.04}},
+ {900, 1200, new double[]{0.01, 0.125}},
+ });
+ }
+
//Average Case
@Test
public void testAvgMult() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), OpCode.MULT);
}
@Test
public void testAvgPlus() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), OpCode.PLUS);
}
//Worst Case
@Test
public void testWorstMult() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorBasicWorst(), OpCode.MULT);
}
@Test
public void testWorstPlus() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorBasicWorst(), OpCode.PLUS);
}
//DensityMap
@Test
public void testDMMult() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorDensityMap(), OpCode.MULT);
}
@Test
public void testDMPlus() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorDensityMap(), OpCode.PLUS);
}
//MNC
@Test
public void testMNCMult() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(),
OpCode.MULT);
}
@Test
public void testMNCPlus() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(),
OpCode.PLUS);
}
//Bitset
@Test
public void testBitsetMult() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), OpCode.MULT);
}
@Test
public void testBitsetPlus() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), OpCode.PLUS);
}
//Layered Graph
@Test
public void testLGCasemult() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13), OpCode.MULT);
}
@Test
public void testLGCaseplus() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13), OpCode.PLUS);
}
-
-
- private static void runSparsityEstimateTest(SparsityEstimator estim,
int m, int n, double[] sp, OpCode op) {
- MatrixBlock m1 = MatrixBlock.randOperations(m, n, sp[0], 1, 1,
"uniform", 3);
- MatrixBlock m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 1,
"uniform", 5);
- MatrixBlock m3 = MatrixBlock.randOperations(n, m, sp[1], 1, 1,
"uniform", 7);
+
+ // Row Wise Sparsity Estimator
+ @Test
+ public void testRowWiseCaseMult() {
+ runSparsityEstimateTest(new EstimatorRowWise(), OpCode.MULT);
+ }
+
+ @Test
+ public void testRowWiseCasePlus() {
+ runSparsityEstimateTest(new EstimatorRowWise(), OpCode.PLUS);
+ }
+
+ private void runSparsityEstimateTest(SparsityEstimator estim, OpCode
op) {
+ MatrixBlock m1 = MatrixBlock.randOperations(m, n, sparsity[0],
1, 1, "uniform", 3);
+ MatrixBlock m2 = MatrixBlock.randOperations(m, n, sparsity[1],
1, 1, "uniform", 5);
+ MatrixBlock m3 = MatrixBlock.randOperations(n, m, sparsity[1],
1, 1, "uniform", 7);
MatrixBlock m4 = new MatrixBlock();
- MatrixBlock m5 = new MatrixBlock();
BinaryOperator bOp;
- double est = 0;
switch(op) {
case MULT:
bOp = new
BinaryOperator(Multiply.getMultiplyFnObject());
- m1.binaryOperations(bOp, m2, m4);
- m5 = m4.aggregateBinaryOperations(m4, m3,
- new MatrixBlock(),
InstructionUtils.getMatMultOperator(1));
- est = estim.estim(new MMNode(new MMNode(new
MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity();
- // System.out.println(m5.getSparsity());
- // System.out.println(est);
break;
case PLUS:
bOp = new
BinaryOperator(Plus.getPlusFnObject());
- m1.binaryOperations(bOp, m2, m4);
- m5 = m4.aggregateBinaryOperations(m4, m3,
- new MatrixBlock(),
InstructionUtils.getMatMultOperator(1));
- est = estim.estim(new MMNode(new MMNode(new
MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity();
- // System.out.println(m5.getSparsity());
- // System.out.println(est);
break;
default:
throw new NotImplementedException();
}
+ m1.binaryOperations(bOp, m2, m4);
+ MatrixBlock m5 = m4.aggregateBinaryOperations(m4, m3,
+ new MatrixBlock(),
InstructionUtils.getMatMultOperator(1));
+ double est = estim.estim(new MMNode(new MMNode(new MMNode(m1),
new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity();
+
//compare estimated and real sparsity
TestUtils.compareScalars(est, m5.getSparsity(), (estim
instanceof EstimatorBasicWorst) ? 9e-1 :
(estim instanceof EstimatorLayeredGraph) ? 7e-2 : 1e-2);
diff --git
a/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java
b/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java
index f8ddb91bce..311ae50cb5 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpElemWTest.java
@@ -19,12 +19,12 @@
package org.apache.sysds.test.component.estim;
-import org.junit.Test;
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
import org.apache.sysds.hops.estim.EstimatorDensityMap;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
+import org.apache.sysds.hops.estim.EstimatorRowWise;
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
import org.apache.sysds.hops.estim.EstimatorSample;
import org.apache.sysds.hops.estim.SparsityEstimator;
@@ -35,124 +35,148 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
import org.apache.commons.lang3.NotImplementedException;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
/**
- * this is the basic operation check for all estimators with single operations
+ * Basic operation check for all estimators with element-wise
operations.
*/
+@RunWith(value = Parameterized.class)
public class OpElemWTest extends AutomatedTestBase
{
- private final static int m = 1600;
- private final static int n = 700;
- private final static double[] sparsity = new double[]{0.2, 0.4};
- private final static OpCode mult = OpCode.MULT;
- private final static OpCode plus = OpCode.PLUS;
+ @Parameterized.Parameter(0)
+ public int m;
+ @Parameterized.Parameter(1)
+ public int n;
+ @Parameterized.Parameter(2)
+ public double[] sparsity;
@Override
public void setUp() {
//do nothing
}
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ // {m, n, sparsity}
+ {1600, 700, new double[]{0.2, 0.4}},
+ {900, 1200, new double[]{0.01, 0.125}},
+ });
+ }
+
//Average Case
@Test
public void testAvgMult() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), OpCode.MULT);
}
@Test
public void testAvgPlus() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), OpCode.PLUS);
}
//Worst Case
@Test
public void testWorstMult() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorBasicWorst(), OpCode.MULT);
}
@Test
public void testWorstPlus() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorBasicWorst(), OpCode.PLUS);
}
//DensityMap
@Test
public void testDMMult() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorDensityMap(), OpCode.MULT);
}
@Test
public void testDMPlus() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorDensityMap(), OpCode.PLUS);
}
//MNC
@Test
public void testMNCMult() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(),
OpCode.MULT);
}
@Test
public void testMNCPlus() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(),
OpCode.PLUS);
}
//Bitset
@Test
public void testBitsetMult() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), OpCode.MULT);
}
@Test
public void testBitsetPlus() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), OpCode.PLUS);
}
//Layered Graph
@Test
public void testLGCasemult() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, n,
sparsity, mult);
+ runSparsityEstimateTest(new EstimatorLayeredGraph(),
OpCode.MULT);
}
@Test
public void testLGCaseplus() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, n,
sparsity, plus);
+ runSparsityEstimateTest(new EstimatorLayeredGraph(),
OpCode.PLUS);
}
//Sample
@Test
public void testSampleMult() {
- runSparsityEstimateTest(new EstimatorSample(), m, n, sparsity,
mult);
+ runSparsityEstimateTest(new EstimatorSample(), OpCode.MULT);
}
@Test
public void testSamplePlus() {
- runSparsityEstimateTest(new EstimatorSample(), m, n, sparsity,
plus);
+ runSparsityEstimateTest(new EstimatorSample(), OpCode.PLUS);
}
-
- private static void runSparsityEstimateTest(SparsityEstimator estim,
int m, int n, double[] sp, OpCode op) {
- MatrixBlock m1 = MatrixBlock.randOperations(m, n, sp[0], 1, 1,
"uniform", 3);
- MatrixBlock m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 1,
"uniform", 7);
+
+ // Row Wise Sparsity Estimator
+ @Test
+ public void testRowWiseMult() {
+ runSparsityEstimateTest(new EstimatorRowWise(), OpCode.MULT);
+ }
+
+ @Test
+ public void testRowWisePlus() {
+ runSparsityEstimateTest(new EstimatorRowWise(), OpCode.PLUS);
+ }
+
+ private void runSparsityEstimateTest(SparsityEstimator estim, OpCode
op) {
+ MatrixBlock m1 = MatrixBlock.randOperations(m, n, sparsity[0],
1, 1, "uniform", 3);
+ MatrixBlock m2 = MatrixBlock.randOperations(m, n, sparsity[1],
1, 1, "uniform", 7);
MatrixBlock m3 = new MatrixBlock();
BinaryOperator bOp;
- double est = 0;
switch(op) {
case MULT:
bOp = new
BinaryOperator(Multiply.getMultiplyFnObject());
- m1.binaryOperations(bOp, m2, m3);
- est = estim.estim(m1, m2, op);
- // System.out.println(est);
- // System.out.println(m3.getSparsity());
break;
case PLUS:
bOp = new
BinaryOperator(Plus.getPlusFnObject());
- m1.binaryOperations(bOp, m2, m3);
- est = estim.estim(m1, m2, op);
- // System.out.println(est);
- // System.out.println(m3.getSparsity());
break;
- default:
- throw new NotImplementedException();
+ default:
+ throw new NotImplementedException();
}
+ m1.binaryOperations(bOp, m2, m3);
+ double est = estim.estim(m1, m2, op);
//compare estimated and real sparsity
TestUtils.compareScalars(est, m3.getSparsity(), (estim
instanceof EstimatorBasicWorst) ? 5e-1 :
(estim instanceof EstimatorLayeredGraph) ? 3e-2 : 5e-3);
diff --git
a/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java
b/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java
index d40f84c4fb..14696fa572 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpSingleTest.java
@@ -19,255 +19,297 @@
package org.apache.sysds.test.component.estim;
-import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
-import org.junit.Test;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
+import org.apache.sysds.hops.estim.EstimatorRowWise;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.hops.estim.SparsityEstimator.OpCode;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
/**
* this is the basic operation check for all estimators with single operations
*/
+@RunWith(value = Parameterized.class)
public class OpSingleTest extends AutomatedTestBase
{
- private final static int m = 600;
- private final static int k = 300;
- private final static double sparsity = 0.2;
-// private final static OpCode eqzero = OpCode.EQZERO;
- private final static OpCode diag = OpCode.DIAG;
- private final static OpCode neqzero = OpCode.NEQZERO;
- private final static OpCode trans = OpCode.TRANS;
- private final static OpCode reshape = OpCode.RESHAPE;
+ @Parameterized.Parameter(0)
+ public int m;
+ @Parameterized.Parameter(1)
+ public int k_param;
+ @Parameterized.Parameter(2)
+ public double sparsity;
@Override
public void setUp() {
//do nothing
}
-
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ // {m, k_param, sparsity}
+ {600, 300, 0.2},
+ {200, 1200, 0.6},
+ });
+ }
+
//Average Case
// @Test
// public void testAvgEqzero() {
-// runSparsityEstimateTest(new EstimatorBasicAvg(), m, k,
sparsity, eqzero);
+// runSparsityEstimateTest(new EstimatorBasicAvg(), k_param,
OpCode.EQZERO);
// }
// @Test
// public void testAvgDiag() {
-// runSparsityEstimateTest(new EstimatorBasicAvg(), m, m,
sparsity, diag);
+// runSparsityEstimateTest(new EstimatorBasicAvg(), m,
OpCode.DIAG);
// }
@Test
public void testAvgNeqzero() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k,
sparsity, neqzero);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), k_param,
OpCode.NEQZERO);
}
@Test
public void testAvgTrans() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k,
sparsity, trans);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), k_param,
OpCode.TRANS);
}
@Test
public void testAvgReshape() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k,
sparsity, reshape);
+ runSparsityEstimateTest(new EstimatorBasicAvg(), k_param,
OpCode.RESHAPE);
}
//Worst Case
// @Test
// public void testWorstEqzero() {
-// runSparsityEstimateTest(new EstimatorBasicWorst(), m, k,
sparsity, eqzero);
+// runSparsityEstimateTest(new EstimatorBasicWorst(), k_param,
OpCode.EQZERO);
// }
// @Test
// public void testWCasediag() {
-// runSparsityEstimateTest(new EstimatorBasicWorst(), m, m,
sparsity, diag);
+// runSparsityEstimateTest(new EstimatorBasicWorst(), m,
OpCode.DIAG);
// }
@Test
public void testWorstNeqzero() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k,
sparsity, neqzero);
+ runSparsityEstimateTest(new EstimatorBasicWorst(), k_param,
OpCode.NEQZERO);
}
@Test
public void testWoestTrans() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k,
sparsity, trans);
+ runSparsityEstimateTest(new EstimatorBasicWorst(), k_param,
OpCode.TRANS);
}
@Test
public void testWorstReshape() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k,
sparsity, reshape);
+ runSparsityEstimateTest(new EstimatorBasicWorst(), k_param,
OpCode.RESHAPE);
}
// //DensityMap
// @Test
// public void testDMCaseeqzero() {
-// runSparsityEstimateTest(new EstimatorDensityMap(), m, k,
sparsity, eqzero);
+// runSparsityEstimateTest(new EstimatorDensityMap(), k_param,
OpCode.EQZERO);
// }
//
// @Test
// public void testDMCasediag() {
-// runSparsityEstimateTest(new EstimatorDensityMap(), m, m,
sparsity, diag);
+// runSparsityEstimateTest(new EstimatorDensityMap(), m,
OpCode.DIAG);
// }
//
// @Test
// public void testDMCaseneqzero() {
-// runSparsityEstimateTest(new EstimatorDensityMap(), m, k,
sparsity, neqzero);
+// runSparsityEstimateTest(new EstimatorDensityMap(), k_param,
OpCode.NEQZERO);
// }
//
// @Test
// public void testDMCasetrans() {
-// runSparsityEstimateTest(new EstimatorDensityMap(), m, k,
sparsity, trans);
+// runSparsityEstimateTest(new EstimatorDensityMap(), k_param,
OpCode.TRANS);
// }
//
// @Test
// public void testDMCasereshape() {
-// runSparsityEstimateTest(new EstimatorDensityMap(), m, k,
sparsity, reshape);
+// runSparsityEstimateTest(new EstimatorDensityMap(), k_param,
OpCode.RESHAPE);
// }
//
// //MNC
// @Test
// public void testMNCCaseeqzero() {
-// runSparsityEstimateTest(new EstimatorDensityMap(), m, k,
sparsity, eqzero);
+// runSparsityEstimateTest(new EstimatorDensityMap(), k_param,
OpCode.EQZERO);
// }
//
// @Test
// public void testMNCCasediag() {
-// runSparsityEstimateTest(new EstimatorDensityMap(), m, m,
sparsity, diag);
+// runSparsityEstimateTest(new EstimatorDensityMap(), m,
OpCode.DIAG);
// }
//
// @Test
// public void testMNCCaseneqzero() {
-// runSparsityEstimateTest(new EstimatorDensityMap(), m, k,
sparsity, neqzero);
+// runSparsityEstimateTest(new EstimatorDensityMap(), k_param,
OpCode.NEQZERO);
// }
//
// @Test
// public void testMNCCasetrans() {
-// runSparsityEstimateTest(new EstimatorDensityMap(), m, k,
sparsity, trans);
+// runSparsityEstimateTest(new EstimatorDensityMap(), k_param,
OpCode.TRANS);
// }
//
// @Test
// public void testMNCCasereshape() {
-// runSparsityEstimateTest(new EstimatorDensityMap(), m, k,
sparsity, reshape);
+// runSparsityEstimateTest(new EstimatorDensityMap(), k_param,
OpCode.RESHAPE);
// }
//
//Bitset
// @Test
// public void testBitsetCaseeqzero() {
-// runSparsityEstimateTest(new EstimatorBitsetMM(), m, k,
sparsity, eqzero);
+// runSparsityEstimateTest(new EstimatorBitsetMM(), k_param,
OpCode.EQZERO);
// }
// @Test
// public void testBitsetCasediag() {
-// runSparsityEstimateTest(new EstimatorBitsetMM(), m, m,
sparsity, diag);
+// runSparsityEstimateTest(new EstimatorBitsetMM(), m,
OpCode.DIAG);
// }
@Test
public void testBitsetNeqzero() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k,
sparsity, neqzero);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), k_param,
OpCode.NEQZERO);
}
@Test
public void testBitsetTrans() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k,
sparsity, trans);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), k_param,
OpCode.TRANS);
}
@Test
public void testBitsetReshape() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k,
sparsity, reshape);
+ runSparsityEstimateTest(new EstimatorBitsetMM(), k_param,
OpCode.RESHAPE);
}
// //Layered Graph
// @Test
// public void testLGCaseeqzero() {
-// runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k,
sparsity, eqzero);
+// runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13), k_param,
OpCode.EQZERO);
// }
//
@Test
public void testLGCasediagM() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, m,
sparsity, diag);
+ runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13), m, OpCode.DIAG);
}
@Test
public void testLGCasediagV() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, 1,
sparsity, diag);
+ runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13), 1, OpCode.DIAG);
}
//
// @Test
// public void testLGCaseneqzero() {
-// runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k,
sparsity, neqzero);
+// runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13), k_param,
OpCode.NEQZERO);
// }
//
@Test
public void testLGCasetrans() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k,
sparsity, trans);
+ runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13), k_param, OpCode.TRANS);
}
// @Test
// public void testLGCasereshape() {
-// runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k,
sparsity, reshape);
+// runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13), k_param,
OpCode.RESHAPE);
// }
//
// //Sample
// @Test
// public void testSampleCaseeqzero() {
-// runSparsityEstimateTest(new EstimatorSample(), m, k, sparsity,
eqzero);
+// runSparsityEstimateTest(new EstimatorSample(), k_param,
OpCode.EQZERO);
// }
//
// @Test
// public void testSampleCasediag() {
-// runSparsityEstimateTest(new EstimatorSample(), m, m, sparsity,
diag);
+// runSparsityEstimateTest(new EstimatorSample(), m, OpCode.DIAG);
// }
//
// @Test
// public void testSampleCaseneqzero() {
-// runSparsityEstimateTest(new EstimatorSample(), m, k, sparsity,
neqzero);
+// runSparsityEstimateTest(new EstimatorSample(), k_param,
OpCode.NEQZERO);
// }
//
// @Test
// public void testSampleCasetrans() {
-// runSparsityEstimateTest(new EstimatorSample(), m, k, sparsity,
trans);
+// runSparsityEstimateTest(new EstimatorSample(), k_param,
OpCode.TRANS);
// }
//
// @Test
// public void testSampleCasereshape() {
-// runSparsityEstimateTest(new EstimatorSample(), m, k, sparsity,
reshape);
+// runSparsityEstimateTest(new EstimatorSample(), k_param,
OpCode.RESHAPE);
// }
-
- private static void runSparsityEstimateTest(SparsityEstimator estim,
int m, int k, double sp, OpCode op) {
- MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp, 1, 1,
"uniform", 3);
- MatrixBlock m2 = new MatrixBlock();
- double est = 0;
+
+ // Row Wise Sparsity Estimator
+ @Test
+ public void testRowWiseEqzero() {
+ runSparsityEstimateTest(new EstimatorRowWise(), k_param,
OpCode.EQZERO);
+ }
+
+ @Test
+ public void testRowWiseDiagM() {
+ runSparsityEstimateTest(new EstimatorRowWise(), m, OpCode.DIAG);
+ }
+
+ @Test
+ public void testRowWiseDiagV() {
+ runSparsityEstimateTest(new EstimatorRowWise(), 1, OpCode.DIAG);
+ }
+
+ @Test
+ public void testRowWiseNeqzero() {
+ runSparsityEstimateTest(new EstimatorRowWise(), k_param,
OpCode.NEQZERO);
+ }
+
+ @Test
+ public void testRowWiseTrans() {
+ runSparsityEstimateTest(new EstimatorRowWise(), k_param,
OpCode.TRANS);
+ }
+
+ @Test
+ public void testRowWiseReshape() {
+ runSparsityEstimateTest(new EstimatorRowWise(), k_param,
OpCode.RESHAPE);
+ }
+
+ private void runSparsityEstimateTest(SparsityEstimator estim, int k,
OpCode op) {
+ MatrixBlock m1 = MatrixBlock.randOperations(m, k, sparsity, 1,
1, "uniform", 3);
+ MatrixBlock m2;
+ double ref = -1;
switch(op) {
case EQZERO:
- //TODO find out how to do eqzero
+ ref = 1 - m1.getSparsity();
+ break;
case DIAG:
m2 = m1.getNumColumns() == 1
? LibMatrixReorg.diag(m1, new
MatrixBlock(m1.getNumRows(), m1.getNumRows(), false))
: LibMatrixReorg.diag(m1, new
MatrixBlock(m1.getNumRows(), 1, false));
- est = estim.estim(m1, op);
+ ref = m2.getSparsity();
break;
case NEQZERO:
- m2 = m1;
- est = estim.estim(m1, op);
- break;
case TRANS:
- m2 = m1;
- est = estim.estim(m1, op);
- break;
case RESHAPE:
m2 = m1;
- est = estim.estim(m1, op);
+ ref = m2.getSparsity();
break;
default:
throw new NotImplementedException();
}
+ double est = estim.estim(m1, op);
//compare estimated and real sparsity
- TestUtils.compareScalars(est, m2.getSparsity(),
+ TestUtils.compareScalars(est, ref,
(estim instanceof EstimatorBasicWorst) ? 5e-1 :
(estim instanceof EstimatorLayeredGraph) ? 3e-2 : 2e-2);
}
diff --git
a/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java
b/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java
index fdc33d878d..f0486a58ca 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OuterProductTest.java
@@ -20,12 +20,19 @@
package org.apache.sysds.test.component.estim;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
import org.apache.sysds.hops.estim.EstimatorDensityMap;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
+import org.apache.sysds.hops.estim.EstimatorRowWise;
import org.apache.sysds.hops.estim.EstimatorSample;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -37,122 +44,90 @@ import org.apache.sysds.test.TestUtils;
* This is a basic sanity check for all estimator, which need
* to compute the exact sparsity for the special case of outer products.
*/
-public class OuterProductTest extends AutomatedTestBase
+@RunWith(value = Parameterized.class)
+public class OuterProductTest extends AutomatedTestBase
{
- private final static int m = 1154;
- private final static int k = 1;
- private final static int n = 900;
- private final static double[] case1 = new double[]{0.1, 0.7};
- private final static double[] case2 = new double[]{0.6, 0.7};
+ @Parameterized.Parameter(0)
+ public int m;
+ @Parameterized.Parameter(1)
+ public int k;
+ @Parameterized.Parameter(2)
+ public int n;
+ @Parameterized.Parameter(3)
+ public double[] sparsity;
@Override
public void setUp() {
//do nothing
}
- @Test
- public void testBasicAvgCase1() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n,
case1);
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ // {m, k, n, sparsity}
+ {1154, 1, 900, new double[]{0.1, 0.7}},
+ {1154, 1, 900, new double[]{0.6, 0.7}},
+ });
}
-
+
@Test
- public void testBasicAvgCase2() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n,
case2);
+ public void testBasicAvgCase1() {
+ runSparsityEstimateTest(new EstimatorBasicAvg());
}
@Test
public void testBasicWorstCase1() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n,
case1);
- }
-
- @Test
- public void testBasicWorstCase2() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n,
case2);
+ runSparsityEstimateTest(new EstimatorBasicWorst());
}
@Test
public void testDensityMapCase1() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n,
case1);
- }
-
- @Test
- public void testDensityMapCase2() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n,
case2);
+ runSparsityEstimateTest(new EstimatorDensityMap());
}
@Test
public void testDensityMap7Case1() {
- runSparsityEstimateTest(new EstimatorDensityMap(7), m, k, n,
case1);
- }
-
- @Test
- public void testDensityMap7Case2() {
- runSparsityEstimateTest(new EstimatorDensityMap(7), m, k, n,
case2);
+ runSparsityEstimateTest(new EstimatorDensityMap(7));
}
@Test
public void testBitsetMatrixCase1() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n,
case1);
- }
-
- @Test
- public void testBitsetMatrixCase2() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n,
case2);
+ runSparsityEstimateTest(new EstimatorBitsetMM());
}
@Test
public void testMatrixHistogramCase1() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m,
k, n, case1);
- }
-
- @Test
- public void testMatrixHistogramCase2() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m,
k, n, case2);
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(false));
}
@Test
public void testMatrixHistogramExceptCase1() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m,
k, n, case1);
- }
-
- @Test
- public void testMatrixHistogramExceptCase2() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m,
k, n, case2);
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(true));
}
@Test
public void testSamplingDefCase1() {
- runSparsityEstimateTest(new EstimatorSample(), m, k, n, case1);
- }
-
- @Test
- public void testSamplingDefCase2() {
- runSparsityEstimateTest(new EstimatorSample(), m, k, n, case2);
+ runSparsityEstimateTest(new EstimatorSample());
}
@Test
public void testSampling20Case1() {
- runSparsityEstimateTest(new EstimatorSample(0.2), m, k, n,
case1);
- }
-
- @Test
- public void testSampling20Case2() {
- runSparsityEstimateTest(new EstimatorSample(0.2), m, k, n,
case2);
+ runSparsityEstimateTest(new EstimatorSample(0.2));
}
@Test
public void testLayeredGraphCase1() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n,
case1);
+ runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13));
}
@Test
- public void testLayeredGraphCase2() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n,
case2);
+ public void testRowWiseCase1() {
+ runSparsityEstimateTest(new EstimatorRowWise());
}
- private static void runSparsityEstimateTest(SparsityEstimator estim,
int m, int k, int n, double[] sp) {
- MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1,
"uniform", 3);
- MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1,
"uniform", 3);
+ private void runSparsityEstimateTest(SparsityEstimator estim) {
+ MatrixBlock m1 = MatrixBlock.randOperations(m, k, sparsity[0],
1, 1, "uniform", 3);
+ MatrixBlock m2 = MatrixBlock.randOperations(k, n, sparsity[1],
1, 1, "uniform", 3);
MatrixBlock m3 = m1.aggregateBinaryOperations(m1, m2,
new MatrixBlock(),
InstructionUtils.getMatMultOperator(1));
diff --git
a/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java
b/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java
index d99f38d939..58e7f2195c 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/SelfProductTest.java
@@ -19,7 +19,14 @@
package org.apache.sysds.test.component.estim;
+import org.junit.Assume;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.estim.EstimationUtils;
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
@@ -28,6 +35,7 @@ import org.apache.sysds.hops.estim.EstimatorBitsetMM;
import org.apache.sysds.hops.estim.EstimatorDensityMap;
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
+import org.apache.sysds.hops.estim.EstimatorRowWise;
import org.apache.sysds.hops.estim.EstimatorSample;
import org.apache.sysds.hops.estim.EstimatorSampleRa;
import org.apache.sysds.hops.estim.SparsityEstimator;
@@ -36,142 +44,118 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
-public class SelfProductTest extends AutomatedTestBase
+@RunWith(value = Parameterized.class)
+public class SelfProductTest extends AutomatedTestBase
{
- private final static int m = 2500;
- private final static double sparsity0 = 0.5;
- private final static double sparsity1 = 0.1;
- private final static double sparsity2 = 0.0001;
- private final static double sparsity3 = 0.000001;
- private final static double eps1 = 0.05;
- private final static double eps2 = 1e-4;
- private final static double eps3 = 0;
-
+ @Parameterized.Parameter(0)
+ public int m;
+ @Parameterized.Parameter(1)
+ public double sparsity;
@Override
public void setUp() {
//do nothing
}
-
- @Test
- public void testBasicAvgCase() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m/4,
sparsity0);
- runSparsityEstimateTest(new EstimatorBasicAvg(), m/2,
sparsity1);
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, sparsity2);
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, sparsity3);
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ // {m, sparsity}
+ {625, 0.5},
+ {1250, 0.1},
+ {2500, 0.0001},
+ {2500, 0.000001},
+ });
}
-
+
@Test
- public void testDensityMapCase() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m/4,
sparsity0);
- runSparsityEstimateTest(new EstimatorDensityMap(), m/2,
sparsity1);
- runSparsityEstimateTest(new EstimatorDensityMap(), m,
sparsity2);
- runSparsityEstimateTest(new EstimatorDensityMap(), m,
sparsity3);
+ public void testBasicAvg() {
+ runSparsityEstimateTest(new EstimatorBasicAvg());
}
@Test
- public void testDensityMap7Case() {
- runSparsityEstimateTest(new EstimatorDensityMap(7), m/4,
sparsity0);
- runSparsityEstimateTest(new EstimatorDensityMap(7), m/2,
sparsity1);
- runSparsityEstimateTest(new EstimatorDensityMap(7), m,
sparsity2);
- runSparsityEstimateTest(new EstimatorDensityMap(7), m,
sparsity3);
+ public void testDensityMap() {
+ runSparsityEstimateTest(new EstimatorDensityMap());
}
@Test
- public void testBitsetMatrixCase() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m/4,
sparsity0);
- runSparsityEstimateTest(new EstimatorBitsetMM(), m/2,
sparsity1);
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, sparsity2);
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, sparsity3);
+ public void testDensityMapBlocksize7() {
+ runSparsityEstimateTest(new EstimatorDensityMap(7));
}
@Test
- public void testBitset2MatrixCase() {
- runSparsityEstimateTest(new EstimatorBitsetMM(2), m/4,
sparsity0);
- runSparsityEstimateTest(new EstimatorBitsetMM(2), m/2,
sparsity1);
- runSparsityEstimateTest(new EstimatorBitsetMM(2), m, sparsity2);
- runSparsityEstimateTest(new EstimatorBitsetMM(2), m, sparsity3);
+ public void testBitsetMatrix() {
+ runSparsityEstimateTest(new EstimatorBitsetMM());
}
@Test
- public void testMatrixHistogramCase() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(false),
m/4, sparsity0);
- runSparsityEstimateTest(new EstimatorMatrixHistogram(false),
m/2, sparsity1);
- runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m,
sparsity2);
- runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m,
sparsity3);
+ public void testBitsetMatrixType2() {
+ runSparsityEstimateTest(new EstimatorBitsetMM(2));
}
@Test
- public void testMatrixHistogramExceptCase() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(true),
m/4, sparsity0);
- runSparsityEstimateTest(new EstimatorMatrixHistogram(true),
m/2, sparsity1);
- runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m,
sparsity2);
- runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m,
sparsity3);
+ public void testMatrixHistogram() {
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(false));
}
@Test
- public void testSamplingDefCase() {
- runSparsityEstimateTest(new EstimatorSample(), m, sparsity2);
- runSparsityEstimateTest(new EstimatorSample(), m, sparsity3);
+ public void testMatrixHistogramExtended() {
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(true));
}
@Test
- public void testSampling20Case() {
- runSparsityEstimateTest(new EstimatorSample(0.2), m, sparsity2);
- runSparsityEstimateTest(new EstimatorSample(0.2), m, sparsity3);
+ public void testSampling() {
+ Assume.assumeTrue(sparsity < 0.1);
+ runSparsityEstimateTest(new EstimatorSample());
}
@Test
- public void testSamplingRaDefCase() {
- runSparsityEstimateTest(new EstimatorSampleRa(), m/4,
sparsity0);
- runSparsityEstimateTest(new EstimatorSampleRa(), m, sparsity2);
- runSparsityEstimateTest(new EstimatorSampleRa(), m, sparsity3);
+ public void testSamplingFrac20() {
+ Assume.assumeTrue(sparsity < 0.1);
+ runSparsityEstimateTest(new EstimatorSample(0.2));
}
@Test
- public void testSamplingRa20Case() {
- runSparsityEstimateTest(new EstimatorSampleRa(0.2), m/4,
sparsity0);
- runSparsityEstimateTest(new EstimatorSampleRa(0.2), m,
sparsity2);
- runSparsityEstimateTest(new EstimatorSampleRa(0.2), m,
sparsity3);
+ public void testSamplingRa() {
+ runSparsityEstimateTest(new EstimatorSampleRa());
}
@Test
- public void testLayeredGraphDefCase() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m,
sparsity2);
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m,
sparsity3);
+ public void testSamplingRaFrac20() {
+ runSparsityEstimateTest(new EstimatorSampleRa(0.2));
}
@Test
- public void testLayeredGraph64Case() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(64), m,
sparsity2);
- runSparsityEstimateTest(new EstimatorLayeredGraph(64), m,
sparsity3);
+ public void testLayeredGraph() {
+ runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13));
}
@Test
- public void testLayeredGraphCase1() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m,
sparsity1);
+ public void testLayeredGraph64Rounds() {
+ runSparsityEstimateTest(new EstimatorLayeredGraph(64, 13));
}
@Test
- public void testLayeredGraphCase2() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m,
sparsity2);
+ public void testRowWise() {
+ runSparsityEstimateTest(new EstimatorRowWise());
}
-
- private static void runSparsityEstimateTest(SparsityEstimator estim,
int n, double sp) {
- MatrixBlock m1 = MatrixBlock.randOperations(n, n, sp, 1, 1,
"uniform", 3);
+
+ private void runSparsityEstimateTest(SparsityEstimator estim) {
+ MatrixBlock m1 = MatrixBlock.randOperations(m, m, sparsity, 1,
1, "uniform", 3);
MatrixBlock m3 = m1.aggregateBinaryOperations(m1, m1,
new MatrixBlock(),
InstructionUtils.getMatMultOperator(1));
- double spExact1 = OptimizerUtils.getSparsity(n, n,
+ double spExact1 = OptimizerUtils.getSparsity(m, m,
EstimationUtils.getSelfProductOutputNnz(m1));
- double spExact2 = sp<0.4 ? OptimizerUtils.getSparsity(n, n,
+ double spExact2 = sparsity<0.4 ? OptimizerUtils.getSparsity(m,
m,
EstimationUtils.getSparseProductOutputNnz(m1, m1)) :
spExact1;
//compare estimated and real sparsity
double est = estim.estim(m1, m1);
TestUtils.compareScalars(est, m3.getSparsity(),
- (estim instanceof EstimatorBitsetMM) ? eps3 : //exact
- (estim instanceof EstimatorBasicWorst || estim
instanceof EstimatorLayeredGraph) ? eps1 : eps2);
- TestUtils.compareScalars(m3.getSparsity(), spExact1, eps3);
- TestUtils.compareScalars(m3.getSparsity(), spExact2, eps3);
+ (estim instanceof EstimatorBitsetMM) ? 0 : //exact
+ (estim instanceof EstimatorBasicWorst || estim
instanceof EstimatorLayeredGraph) ? 0.05 :
+ (sparsity == 0.1 && estim instanceof EstimatorSampleRa)
? 0.12 : 1e-4);
+ TestUtils.compareScalars(m3.getSparsity(), spExact1, 0);
+ TestUtils.compareScalars(m3.getSparsity(), spExact2, 0);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
index f799b02c96..a2b04b34df 100644
---
a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
+++
b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
@@ -19,13 +19,13 @@
package org.apache.sysds.test.component.estim;
-import org.junit.Test;
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
import org.apache.sysds.hops.estim.EstimatorDensityMap;
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
+import org.apache.sysds.hops.estim.EstimatorRowWise;
import org.apache.sysds.hops.estim.MMNode;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.hops.estim.SparsityEstimator.OpCode;
@@ -34,123 +34,99 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
/**
* This is a basic sanity check for all estimator, which need
* to compute a reasonable estimate for uniform data.
*/
-public class SquaredProductChainTest extends AutomatedTestBase
+@RunWith(value = Parameterized.class)
+public class SquaredProductChainTest extends AutomatedTestBase
{
- private final static int m = 1000;
- private final static int k = 1000;
- private final static int n = 1000;
- private final static int n2 = 1000;
- private final static double[] case1 = new double[]{0.0001, 0.00007,
0.001};
- private final static double[] case2 = new double[]{0.0006, 0.00007,
0.001};
+ @Parameterized.Parameter(0)
+ public int m;
+ @Parameterized.Parameter(1)
+ public int k;
+ @Parameterized.Parameter(2)
+ public int n;
+ @Parameterized.Parameter(3)
+ public int n2;
+ @Parameterized.Parameter(4)
+ public double[] sparsity;
- private final static double eps1 = 1.0;
- private final static double eps2 = 1e-4;
- private final static double eps3 = 0;
-
-
@Override
public void setUp() {
//do nothing
}
-
- @Test
- public void testBasicAvgCase1() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n, n2,
case1);
- }
-
- @Test
- public void testBasicAvgCase2() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n, n2,
case2);
- }
-
- @Test
- public void testBasicWorstCase1() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n, n2,
case1);
- }
-
- @Test
- public void testBasicWorstCase2() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n, n2,
case2);
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ // {m, k, n, n2, sparsity}
+ {1000, 1000, 1000, 1000, new double[]{0.0001, 0.00007,
0.001}},
+ {1000, 1000, 1000, 1000, new double[]{0.0006, 0.00007,
0.001}},
+ });
}
@Test
- public void testDensityMapCase1() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n, n2,
case1);
+ public void testBasicAvg() {
+ runSparsityEstimateTest(new EstimatorBasicAvg());
}
@Test
- public void testDensityMapCase2() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n, n2,
case2);
+ public void testBasicWorst() {
+ runSparsityEstimateTest(new EstimatorBasicWorst());
}
@Test
- public void testDensityMap7Case1() {
- runSparsityEstimateTest(new EstimatorDensityMap(7), m, k, n,
n2, case1);
+ public void testDensityMap() {
+ runSparsityEstimateTest(new EstimatorDensityMap());
}
@Test
- public void testDensityMap7Case2() {
- runSparsityEstimateTest(new EstimatorDensityMap(7), m, k, n,
n2, case2);
+ public void testDensityMapBlocksize7() {
+ runSparsityEstimateTest(new EstimatorDensityMap(7));
}
@Test
- public void testBitsetMatrixCase1() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n, n2,
case1);
+ public void testBitsetMatrix() {
+ runSparsityEstimateTest(new EstimatorBitsetMM());
}
@Test
- public void testBitsetMatrixCase2() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n, n2,
case2);
+ public void testMatrixHistogram() {
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(false));
}
@Test
- public void testMatrixHistogramCase1() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m,
k, n, n2, case1);
+ public void testMatrixHistogramExcept() {
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(true));
}
@Test
- public void testMatrixHistogramCase2() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m,
k, n, n2, case2);
+ public void testLayeredGraph() {
+ runSparsityEstimateTest(new
EstimatorLayeredGraph(EstimatorLayeredGraph.ROUNDS, 13));
}
@Test
- public void testMatrixHistogramExceptCase1() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m,
k, n, n2, case1);
+ public void testLayeredGraph32Rounds() {
+ runSparsityEstimateTest(new EstimatorLayeredGraph(32, 13));
}
@Test
- public void testMatrixHistogramExceptCase2() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m,
k, n, n2, case2);
- }
-
- @Test
- public void testLayeredGraphCase1() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n,
n2, case1);
+ public void testRowWise() {
+ runSparsityEstimateTest(new EstimatorRowWise());
}
- @Test
- public void testLayeredGraphCase2() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n,
n2, case2);
- }
-
- @Test
- public void testLayeredGraph32Case1() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(32), m, k, n,
n2, case1);
- }
-
- @Test
- public void testLayeredGraph32Case2() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(32), m, k, n,
n2, case2);
- }
-
- private static void runSparsityEstimateTest(SparsityEstimator estim,
int m, int k, int n, int n2, double[] sp) {
- MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1,
"uniform", 1);
- MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1,
"uniform", 2);
- MatrixBlock m3 = MatrixBlock.randOperations(n, n2, sp[2], 1, 1,
"uniform", 3);
+ private void runSparsityEstimateTest(SparsityEstimator estim) {
+ MatrixBlock m1 = MatrixBlock.randOperations(m, k, sparsity[0],
1, 1, "uniform", 1);
+ MatrixBlock m2 = MatrixBlock.randOperations(k, n, sparsity[1],
1, 1, "uniform", 2);
+ MatrixBlock m3 = MatrixBlock.randOperations(n, n2, sparsity[2],
1, 1, "uniform", 3);
MatrixBlock m4 = m1.aggregateBinaryOperations(m1, m2,
new MatrixBlock(),
InstructionUtils.getMatMultOperator(1));
MatrixBlock m5 = m4.aggregateBinaryOperations(m4, m3,
@@ -160,7 +136,7 @@ public class SquaredProductChainTest extends
AutomatedTestBase
double est = estim.estim(new MMNode(new MMNode(new MMNode(m1),
new MMNode(m2),
OpCode.MM), new MMNode(m3), OpCode.MM)).getSparsity();
TestUtils.compareScalars(est, m5.getSparsity(),
- (estim instanceof EstimatorBitsetMM) ? eps3 : //exact
- (estim instanceof EstimatorBasicWorst) ? eps1 : eps2);
+ (estim instanceof EstimatorBitsetMM) ? 0 : //exact
+ (estim instanceof EstimatorBasicWorst) ? 1.0 : 1e-4);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java
b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java
index 2a898f9c39..d117b98c1c 100644
---
a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java
+++
b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductTest.java
@@ -19,12 +19,12 @@
package org.apache.sysds.test.component.estim;
-import org.junit.Test;
import org.apache.sysds.hops.estim.EstimatorBasicAvg;
import org.apache.sysds.hops.estim.EstimatorBasicWorst;
import org.apache.sysds.hops.estim.EstimatorBitsetMM;
import org.apache.sysds.hops.estim.EstimatorDensityMap;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
+import org.apache.sysds.hops.estim.EstimatorRowWise;
import org.apache.sysds.hops.estim.EstimatorLayeredGraph;
import org.apache.sysds.hops.estim.EstimatorSample;
import org.apache.sysds.hops.estim.SparsityEstimator;
@@ -33,138 +33,108 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
/**
* This is a basic sanity check for all estimator, which need
* to compute a reasonable estimate for uniform data.
*/
-public class SquaredProductTest extends AutomatedTestBase
+@RunWith(value = Parameterized.class)
+public class SquaredProductTest extends AutomatedTestBase
{
- private final static int m = 1000;
- private final static int k = 1000;
- private final static int n = 1000;
- private final static double[] case1 = new double[]{0.0001, 0.00007};
- private final static double[] case2 = new double[]{0.0006, 0.00007};
-
- private final static double eps1 = 0.05;
- private final static double eps2 = 1e-4;
- private final static double eps3 = 0;
-
+ @Parameterized.Parameter(0)
+ public int m;
+ @Parameterized.Parameter(1)
+ public int k;
+ @Parameterized.Parameter(2)
+ public int n;
+ @Parameterized.Parameter(3)
+ public double[] sparsity;
@Override
public void setUp() {
//do nothing
}
-
- @Test
- public void testBasicAvgCase1() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n,
case1);
- }
-
- @Test
- public void testBasicAvgCase2() {
- runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n,
case2);
- }
-
- @Test
- public void testBasicWorstCase1() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n,
case1);
- }
-
- @Test
- public void testBasicWorstCase2() {
- runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n,
case2);
- }
-
- @Test
- public void testDensityMapCase1() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n,
case1);
- }
-
- @Test
- public void testDensityMapCase2() {
- runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n,
case2);
- }
-
- @Test
- public void testDensityMap7Case1() {
- runSparsityEstimateTest(new EstimatorDensityMap(7), m, k, n,
case1);
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ // {m, k, n, sparsity}
+ {1000, 1000, 1000, new double[]{0.0001, 0.00007}},
+ {1000, 1000, 1000, new double[]{0.0006, 0.00007}},
+ });
}
@Test
- public void testDensityMap7Case2() {
- runSparsityEstimateTest(new EstimatorDensityMap(7), m, k, n,
case2);
+ public void testBasicAvg() {
+ runSparsityEstimateTest(new EstimatorBasicAvg());
}
@Test
- public void testBitsetMatrixCase1() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n,
case1);
+ public void testBasicWorst() {
+ runSparsityEstimateTest(new EstimatorBasicWorst());
}
@Test
- public void testBitsetMatrixCase2() {
- runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n,
case2);
+ public void testDensityMap() {
+ runSparsityEstimateTest(new EstimatorDensityMap());
}
@Test
- public void testMatrixHistogramCase1() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m,
k, n, case1);
+ public void testDensityMapBlocksize7() {
+ runSparsityEstimateTest(new EstimatorDensityMap(7));
}
@Test
- public void testMatrixHistogramCase2() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m,
k, n, case2);
+ public void testBitsetMatrix() {
+ runSparsityEstimateTest(new EstimatorBitsetMM());
}
@Test
- public void testMatrixHistogramExceptCase1() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m,
k, n, case1);
+ public void testMatrixHistogram() {
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(false));
}
@Test
- public void testMatrixHistogramExceptCase2() {
- runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m,
k, n, case2);
+ public void testMatrixHistogramExcept() {
+ runSparsityEstimateTest(new EstimatorMatrixHistogram(true));
}
@Test
- public void testSamplingDefCase1() {
- runSparsityEstimateTest(new EstimatorSample(), m, k, n, case1);
+ public void testSampling() {
+ runSparsityEstimateTest(new EstimatorSample());
}
@Test
- public void testSamplingDefCase2() {
- runSparsityEstimateTest(new EstimatorSample(), m, k, n, case2);
- }
-
- @Test
- public void testSampling20Case1() {
- runSparsityEstimateTest(new EstimatorSample(0.2), m, k, n,
case1);
- }
-
- @Test
- public void testSampling20Case2() {
- runSparsityEstimateTest(new EstimatorSample(0.2), m, k, n,
case2);
+ public void testSamplingFrac20() {
+ runSparsityEstimateTest(new EstimatorSample(0.2));
}
@Test
- public void testLayeredGraphCase1() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n,
case1);
+ public void testLayeredGraph() {
+ runSparsityEstimateTest(new EstimatorLayeredGraph());
}
@Test
- public void testLayeredGraphCase2() {
- runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n,
case2);
+ public void testRowWise() {
+ runSparsityEstimateTest(new EstimatorRowWise());
}
-
- private static void runSparsityEstimateTest(SparsityEstimator estim,
int m, int k, int n, double[] sp) {
- MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1,
"uniform", 3);
- MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1,
"uniform", 7);
+
+ private void runSparsityEstimateTest(SparsityEstimator estim) {
+ MatrixBlock m1 = MatrixBlock.randOperations(m, k, sparsity[0],
1, 1, "uniform", 3);
+ MatrixBlock m2 = MatrixBlock.randOperations(k, n, sparsity[1],
1, 1, "uniform", 7);
MatrixBlock m3 = m1.aggregateBinaryOperations(m1, m2,
new MatrixBlock(),
InstructionUtils.getMatMultOperator(1));
//compare estimated and real sparsity
double est = estim.estim(m1, m2);
TestUtils.compareScalars(est, m3.getSparsity(),
- (estim instanceof EstimatorBitsetMM) ? eps3 : //exact
- (estim instanceof EstimatorBasicWorst) ? eps1 : eps2);
+ (estim instanceof EstimatorBitsetMM) ? 0 : //exact
+ (estim instanceof EstimatorBasicWorst) ? 0.05 : 1e-4);
}
}