Repository: systemml
Updated Branches:
  refs/heads/master 78bfb7712 -> e7d948f9c


[SYSTEMML-2288,2295] Fix sparsity estimates of mm chains (density maps, bitsets)

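This patch fixes the sparsity estimation logic for matrix multiply
chains, specifically for the estimators based on density maps and
bitsets. Both estimators took the synopsis of the right input from the
left child, and MMNode(left, right) derived its characteristics from
its (null) data block instead of from its children; a dedicated leaf
constructor MMNode(MatrixBlock) is added. Additionally, this patch
includes related test cases.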


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d17a2e22
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d17a2e22
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d17a2e22

Branch: refs/heads/master
Commit: d17a2e22917bd04b25ed0ad7e050f869cb1da92b
Parents: 78bfb77
Author: Matthias Boehm <[email protected]>
Authored: Wed May 2 15:55:32 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed May 2 15:55:32 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/estim/EstimatorBitsetMM.java     |   2 +-
 .../sysml/hops/estim/EstimatorDensityMap.java   |   4 +-
 .../org/apache/sysml/hops/estim/MMNode.java     |  12 +-
 .../estim/SquaredProductChainTest.java          | 134 +++++++++++++++++++
 .../functions/estim/ZPackageSuite.java          |   2 +
 5 files changed, 148 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d17a2e22/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
index c90fbfc..8bb5805 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
@@ -45,7 +45,7 @@ public class EstimatorBitsetMM extends SparsityEstimator
                BitsetMatrix m1Map = !root.getLeft().isLeaf() ?
                        (BitsetMatrix)root.getLeft().getSynopsis() : new 
BitsetMatrix(root.getLeft().getData());
                BitsetMatrix m2Map = !root.getRight().isLeaf() ?
-                       (BitsetMatrix)root.getLeft().getSynopsis() : new 
BitsetMatrix(root.getLeft().getData());
+                       (BitsetMatrix)root.getRight().getSynopsis() : new 
BitsetMatrix(root.getRight().getData());
                
                //estimate output density map and sparsity via boolean matrix 
mult
                BitsetMatrix outMap = m1Map.matMult(m2Map);

http://git-wip-us.apache.org/repos/asf/systemml/blob/d17a2e22/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
index 1883d59..5244cad 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
@@ -61,11 +61,11 @@ public class EstimatorDensityMap extends SparsityEstimator
                MatrixBlock m1Map = !root.getLeft().isLeaf() ?
                        (MatrixBlock)root.getLeft().getSynopsis() : 
computeDensityMap(root.getLeft().getData());
                MatrixBlock m2Map = !root.getRight().isLeaf() ?
-                       (MatrixBlock)root.getLeft().getSynopsis() : 
computeDensityMap(root.getLeft().getData());
+                       (MatrixBlock)root.getRight().getSynopsis() : 
computeDensityMap(root.getRight().getData());
                
                //estimate output density map and sparsity
                MatrixBlock outMap = estimIntern(m1Map, m2Map,
-                       true, root.getRows(), root.getLeft().getCols(), 
root.getCols());
+                       false, root.getRows(), root.getLeft().getCols(), 
root.getCols());
                root.setSynopsis(outMap); //memoize density map
                return OptimizerUtils.getSparsity( //aggregate output histogram
                        root.getRows(), root.getCols(), (long)outMap.sum());

http://git-wip-us.apache.org/repos/asf/systemml/blob/d17a2e22/src/main/java/org/apache/sysml/hops/estim/MMNode.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/MMNode.java 
b/src/main/java/org/apache/sysml/hops/estim/MMNode.java
index 2209943..55aee3d 100644
--- a/src/main/java/org/apache/sysml/hops/estim/MMNode.java
+++ b/src/main/java/org/apache/sysml/hops/estim/MMNode.java
@@ -34,13 +34,19 @@ public class MMNode
        private final MatrixCharacteristics _mc;
        private Object _synops = null;
        
+       public MMNode(MatrixBlock in) {
+               _m1 = null;
+               _m2 = null;
+               _data = in;
+               _mc = in.getMatrixCharacteristics();
+       }
+       
        public MMNode(MMNode left, MMNode right) {
                _m1 = left;
                _m2 = right;
                _data = null;
-               _mc = isLeaf() ? _data.getMatrixCharacteristics() :
-                       new MatrixCharacteristics(_data.getNumRows(),
-                       _data.getNumColumns(), -1, -1);
+               _mc = new MatrixCharacteristics(
+                       _m1.getRows(), _m2.getCols(), -1, -1);
        }
        
        public int getRows() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d17a2e22/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java
new file mode 100644
index 0000000..82e45ed
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.estim;
+
+import org.junit.Test;
+import org.apache.sysml.hops.estim.EstimatorBasicAvg;
+import org.apache.sysml.hops.estim.EstimatorBasicWorst;
+import org.apache.sysml.hops.estim.EstimatorBitsetMM;
+import org.apache.sysml.hops.estim.EstimatorDensityMap;
+import org.apache.sysml.hops.estim.EstimatorMatrixHistogram;
+import org.apache.sysml.hops.estim.MMNode;
+import org.apache.sysml.hops.estim.SparsityEstimator;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.utils.TestUtils;
+
+/**
+ * This is a basic sanity check for all estimators, each of which needs
+ * to compute a reasonable estimate for uniform data.
+ */
+public class SquaredProductChainTest extends AutomatedTestBase 
+{
+       private final static int m = 1000;
+       private final static int k = 1000;
+       private final static int n = 1000;
+       private final static int n2 = 1000;
+       private final static double[] case1 = new double[]{0.0001, 0.00007, 
0.001};
+       private final static double[] case2 = new double[]{0.0006, 0.00007, 
0.001};
+
+       private final static double eps1 = 1.0;
+       private final static double eps2 = 1e-4;
+       private final static double eps3 = 0;
+       
+       
+       @Override
+       public void setUp() {
+               //do  nothing
+       }
+       
+       @Test
+       public void testBasicAvgCase1() {
+               runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n, n2, 
case1);
+       }
+       
+       @Test
+       public void testBasicAvgCase2() {
+               runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n, n2, 
case2);
+       }
+       
+       @Test
+       public void testBasicWorstCase1() {
+               runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n, n2, 
case1);
+       }
+       
+       @Test
+       public void testBasicWorstCase2() {
+               runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n, n2, 
case2);
+       }
+       
+       @Test
+       public void testDensityMapCase1() {
+               runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n, n2, 
case1);
+       }
+       
+       @Test
+       public void testDensityMapCase2() {
+               runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n, n2, 
case2);
+       }
+       
+       @Test
+       public void testDensityMap7Case1() {
+               runSparsityEstimateTest(new EstimatorDensityMap(7), m, k, n, 
n2, case1);
+       }
+       
+       @Test
+       public void testDensityMap7Case2() {
+               runSparsityEstimateTest(new EstimatorDensityMap(7), m, k, n, 
n2, case2);
+       }
+       
+       @Test
+       public void testBitsetMatrixCase1() {
+               runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n, n2, 
case1);
+       }
+       
+       @Test
+       public void testBitsetMatrixCase2() {
+               runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n, n2, 
case2);
+       }
+       
+       @Test
+       public void testMatrixHistogramCase1() {
+               runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k, 
n, n2, case1);
+       }
+       
+       @Test
+       public void testMatrixHistogramCase2() {
+               runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k, 
n, n2, case2);
+       }
+       
+       private void runSparsityEstimateTest(SparsityEstimator estim, int m, 
int k, int n, int n2, double[] sp) {
+               MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, 
"uniform", 1);
+               MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1, 
"uniform", 2);
+               MatrixBlock m3 = MatrixBlock.randOperations(n, n2, sp[2], 1, 1, 
"uniform", 3);
+               MatrixBlock m4 = m1.aggregateBinaryOperations(m1, m2, 
+                       new MatrixBlock(), 
InstructionUtils.getMatMultOperator(1));
+               MatrixBlock m5 = m1.aggregateBinaryOperations(m4, m3, 
+                       new MatrixBlock(), 
InstructionUtils.getMatMultOperator(1));
+               
+               //compare estimated and real sparsity
+               double est = estim.estim(new MMNode(
+                       new MMNode(new MMNode(m1), new MMNode(m2)), new 
MMNode(m3)));
+               TestUtils.compareScalars(est, m5.getSparsity(),
+                       (estim instanceof EstimatorBitsetMM) ? eps3 : //exact
+                       (estim instanceof EstimatorBasicWorst) ? eps1 : eps2);
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/d17a2e22/src/test_suites/java/org/apache/sysml/test/integration/functions/estim/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/estim/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/estim/ZPackageSuite.java
index 1693f18..82d1891 100644
--- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/estim/ZPackageSuite.java
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/estim/ZPackageSuite.java
@@ -27,6 +27,8 @@ import org.junit.runners.Suite;
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
        OuterProductTest.class,
+       SquaredProductChainTest.class,
+       SquaredProductTest.class,
 })
 
 

Reply via email to