This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit b537181802f552f89014c282b681fea7c06ef404
Author: baunsgaard <[email protected]>
AuthorDate: Mon Sep 20 18:42:02 2021 +0200

    [SYSTEMDS-2610] CLA Updates
    
    - CLA Decompression exploiting common value
    - Use compression ratio to multiply with cost of some operations
    - Lower the default sampling ratio (0.05 -> 0.02)
    - If the sample size is very small, double it
    - Workload tree corrections
    - Set a max sample size in compression settings.
    - Add Hybrid co-coding strategy that uses both PriorityQue and Greedy (all-compare) co-coding
---
 src/main/java/org/apache/sysds/conf/DMLConfig.java |   2 +-
 .../runtime/compress/CompressedMatrixBlock.java    | 100 +---------
 .../compress/CompressedMatrixBlockFactory.java     |   1 +
 .../runtime/compress/CompressionSettings.java      |   8 +-
 .../compress/CompressionSettingsBuilder.java       |  21 +-
 .../runtime/compress/cocode/CoCodeGreedy.java      |  33 ++--
 .../runtime/compress/cocode/CoCodeHybrid.java      |  56 ++++++
 .../runtime/compress/cocode/CoCodePriorityQue.java |  32 +--
 .../runtime/compress/cocode/CoCoderFactory.java    |   5 +-
 .../runtime/compress/colgroup/ColGroupSDC.java     |   8 +-
 .../compress/colgroup/ColGroupSDCSingle.java       |   9 +-
 .../compress/colgroup/dictionary/ADictionary.java  |   4 +-
 .../compress/colgroup/dictionary/Dictionary.java   |  11 +-
 .../compress/cost/ComputationCostEstimator.java    |  68 ++++---
 .../estim/CompressedSizeEstimatorFactory.java      |  27 ++-
 .../runtime/compress/lib/CLALibBinaryCellOp.java   |  16 +-
 .../runtime/compress/lib/CLALibDecompress.java     | 218 +++++++++++++++++++++
 .../runtime/compress/lib/CLALibLeftMultBy.java     |  44 +----
 .../sysds/runtime/compress/lib/CLALibScalar.java   |   5 +-
 .../sysds/runtime/compress/lib/CLALibUtils.java    |  79 ++++++++
 .../compress/workload/WorkloadAnalyzer.java        |  58 +++---
 .../component/compress/workload/WorkloadTest.java  |   2 +-
 .../compress/configuration/CompressForce.java      |   2 +-
 23 files changed, 566 insertions(+), 243 deletions(-)

diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java 
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index db59505..a59101b 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -131,7 +131,7 @@ public class DMLConfig
                _defaultVals.put(COMPRESSED_LOSSY,       "false" );
                _defaultVals.put(COMPRESSED_VALID_COMPRESSIONS, "SDC,DDC");
                _defaultVals.put(COMPRESSED_OVERLAPPING, "true" );
-               _defaultVals.put(COMPRESSED_SAMPLING_RATIO, "0.05");
+               _defaultVals.put(COMPRESSED_SAMPLING_RATIO, "0.02");
                _defaultVals.put(COMPRESSED_COCODE,      "AUTO");
                _defaultVals.put(COMPRESSED_COST_MODEL,  "AUTO");
                _defaultVals.put(COMPRESSED_TRANSPOSE,   "auto");
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index bff9a21..4ff5bcb 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -27,12 +27,8 @@ import java.io.ObjectOutput;
 import java.lang.ref.SoftReference;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 
 import org.apache.commons.lang.NotImplementedException;
@@ -56,6 +52,7 @@ import 
org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
 import org.apache.sysds.runtime.compress.lib.CLALibAppend;
 import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
 import org.apache.sysds.runtime.compress.lib.CLALibCompAgg;
+import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
 import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
 import org.apache.sysds.runtime.compress.lib.CLALibReExpand;
 import org.apache.sysds.runtime.compress.lib.CLALibRightMultBy;
@@ -104,7 +101,6 @@ import 
org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
 import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
 import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
-import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.IndexRange;
 import org.apache.sysds.utils.DMLCompressionStatistics;
 
@@ -250,15 +246,10 @@ public class CompressedMatrixBlock extends MatrixBlock {
 
                ret.allocateDenseBlock();
 
-               if(isOverlapping()){
-                       Comparator<AColGroup> comp = Comparator.comparing(x -> 
effect(x));
-                       _colGroups.sort(comp);
-               }
-
                if(k == 1)
-                       decompress(ret);
+                       CLALibDecompress.decompress(ret, getColGroups(), 
nonZeros, isOverlapping());
                else
-                       decompress(ret, k);
+                       CLALibDecompress.decompress(ret, getColGroups(), 
isOverlapping(), k);
 
                if(this.isOverlapping())
                        ret.recomputeNonZeros();
@@ -275,47 +266,6 @@ public class CompressedMatrixBlock extends MatrixBlock {
                return ret;
        }
 
-       private double effect(AColGroup x){
-               return - Math.max(x.getMax(), Math.abs(x.getMin()));
-       }
-
-       private MatrixBlock decompress(MatrixBlock ret) {
-
-               ret.setNonZeros(nonZeros == -1 && !this.isOverlapping() ? 
recomputeNonZeros() : nonZeros);
-               final int block = (int) Math.ceil((double) 
(CompressionSettings.BITMAP_BLOCK_SZ) / getNumColumns());
-               final int blklen = block > 1000 ? block + 1000 - block % 1000 : 
Math.max(64, block);
-               for(int i = 0; i < getNumRows(); i += blklen)
-                       for(AColGroup grp : _colGroups)
-                               grp.decompressToBlockUnSafe(ret, i, Math.min(i 
+ blklen, rlen));
-
-               return ret;
-       }
-
-       private MatrixBlock decompress(MatrixBlock ret, int k) {
-               try {
-                       final ExecutorService pool = CommonThreadPool.get(k);
-                       final int rlen = getNumRows();
-                       final int block = (int) Math.ceil((double) 
(CompressionSettings.BITMAP_BLOCK_SZ) / getNumColumns());
-                       final int blklen = block > 1000 ? block + 1000 - block 
% 1000 : Math.max(64, block);
-                       final ArrayList<DecompressTask> tasks = new 
ArrayList<>();
-                       for(int i = 0; i * blklen < getNumRows(); i++)
-                               tasks.add(new DecompressTask(_colGroups, ret, i 
* blklen, Math.min((i + 1) * blklen, rlen),
-                                       overlappingColGroups));
-                       List<Future<Long>> rtasks = pool.invokeAll(tasks);
-                       pool.shutdown();
-
-                       long nnz = 0;
-                       for(Future<Long> rt : rtasks)
-                               nnz += rt.get();
-                       ret.setNonZeros(nnz);
-               }
-               catch(InterruptedException | ExecutionException ex) {
-                       throw new DMLCompressionException("Parallel 
decompression failed", ex);
-               }
-
-               return ret;
-       }
-
        /**
         * Get the cached decompressed matrix (if it exists otherwise null).
         * 
@@ -673,10 +623,10 @@ public class CompressedMatrixBlock extends MatrixBlock {
                        ReorgOperator r_op = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads());
                        ret = ret.reorgOperations(r_op, new MatrixBlock(), 0, 
0, 0);
                }
-               
+
                if(ret.getNumRows() == 0 || ret.getNumColumns() == 0)
                        throw new DMLCompressionException("Error in outputted 
MM no dimensions");
-               
+
                return ret;
        }
 
@@ -788,46 +738,6 @@ public class CompressedMatrixBlock extends MatrixBlock {
                return null;
        }
 
-       private static class DecompressTask implements Callable<Long> {
-               private final List<AColGroup> _colGroups;
-               private final MatrixBlock _ret;
-               private final int _rl;
-               private final int _ru;
-               private final boolean _overlapping;
-
-               protected DecompressTask(List<AColGroup> colGroups, MatrixBlock 
ret, int rl, int ru, boolean overlapping) {
-                       _colGroups = colGroups;
-                       _ret = ret;
-                       _rl = rl;
-                       _ru = ru;
-                       _overlapping = overlapping;
-               }
-
-               @Override
-               public Long call() {
-
-                       // preallocate sparse rows to avoid repeated alloc
-                       if(!_overlapping && _ret.isInSparseFormat()) {
-                               int[] rnnz = new int[_ru - _rl];
-                               for(AColGroup grp : _colGroups)
-                                       grp.countNonZerosPerRow(rnnz, _rl, _ru);
-                               SparseBlock rows = _ret.getSparseBlock();
-                               for(int i = _rl; i < _ru; i++)
-                                       rows.allocate(i, rnnz[i - _rl]);
-                       }
-
-                       // decompress row partition
-                       for(AColGroup grp : _colGroups)
-                               grp.decompressToBlockUnSafe(_ret, _rl, _ru);
-
-                       // post processing (sort due to append)
-                       if(_ret.isInSparseFormat())
-                               _ret.sortSparseRows(_rl, _ru);
-
-                       return _overlapping ? 0 : _ret.recomputeNonZeros(_rl, 
_ru - 1);
-               }
-       }
-
        @Override
        public String toString() {
                StringBuilder sb = new StringBuilder();
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
 
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
index 98567e0..d77ab82 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
@@ -484,6 +484,7 @@ public class CompressedMatrixBlockFactory {
                                case 1:
                                        LOG.debug("--compression phase " + 
phase + " Grouping  : " + getLastTimePhase());
                                        LOG.debug("Grouping using: " + 
compSettings.columnPartitioner);
+                                       LOG.debug("Cost Calculated using: " + 
costEstimator);
                                        LOG.debug("--Cocoded Columns estimated 
Compression:" + _stats.estimatedSizeCoCoded);
                                        break;
                                case 2:
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java 
b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
index b347031..c1a9cd4 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
@@ -92,6 +92,11 @@ public class CompressionSettings {
         */
        public final int minimumSampleSize;
 
+       /**
+        * The maximum size of the sample extracted.
+        */
+       public final int maxSampleSize;
+
        /** The sample type used for sampling */
        public final EstimationType estimationType;
 
@@ -110,7 +115,7 @@ public class CompressionSettings {
 
        protected CompressionSettings(double samplingRatio, boolean 
allowSharedDictionary, String transposeInput, int seed,
                boolean lossy, EnumSet<CompressionType> validCompressions, 
boolean sortValuesByLength,
-               PartitionerType columnPartitioner, int maxColGroupCoCode, 
double coCodePercentage, int minimumSampleSize,
+               PartitionerType columnPartitioner, int maxColGroupCoCode, 
double coCodePercentage, int minimumSampleSize, int maxSampleSize,
                EstimationType estimationType, CostType costComputationType, 
double minimumCompressionRatio) {
                this.samplingRatio = samplingRatio;
                this.allowSharedDictionary = allowSharedDictionary;
@@ -123,6 +128,7 @@ public class CompressionSettings {
                this.maxColGroupCoCode = maxColGroupCoCode;
                this.coCodePercentage = coCodePercentage;
                this.minimumSampleSize = minimumSampleSize;
+               this.maxSampleSize= maxSampleSize;
                this.estimationType = estimationType;
                this.costComputationType = costComputationType;
                this.minimumCompressionRatio = minimumCompressionRatio;
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
 
b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
index 156fdfc..d5fd036 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
@@ -42,6 +42,7 @@ public class CompressionSettingsBuilder {
        private int maxColGroupCoCode = 10000;
        private double coCodePercentage = 0.01;
        private int minimumSampleSize = 2000;
+       private int maxSampleSize = 1000000;
        private EstimationType estimationType = EstimationType.HassAndStokes;
        private PartitionerType columnPartitioner;
        private CostType costType;
@@ -247,6 +248,18 @@ public class CompressionSettingsBuilder {
        }
 
        /**
+        * Set the maximum sample size to extract from a given matrix, this 
overrules the sample percentage if the sample
+        * percentage extracted is higher than this maximum bound.
+        * 
+        * @param maxSampleSize The maximum sample size to extract
+        * @return The CompressionSettingsBuilder
+        */
+       public CompressionSettingsBuilder setMaxSampleSize(int maxSampleSize) {
+               this.maxSampleSize = maxSampleSize;
+               return this;
+       }
+
+       /**
         * Set the estimation type used for the sampled estimates.
         * 
         * @param estimationType the estimation type in used.
@@ -268,6 +281,12 @@ public class CompressionSettingsBuilder {
                return this;
        }
 
+       /**
+        * Set the minimum compression ratio to be achieved by the compression.
+        * 
+        * @param ratio The ratio to achieve while compressing
+        * @return The CompressionSettingsBuilder
+        */
        public CompressionSettingsBuilder setMinimumCompressionRatio(double 
ratio) {
                this.minimumCompressionRatio = ratio;
                return this;
@@ -281,6 +300,6 @@ public class CompressionSettingsBuilder {
        public CompressionSettings create() {
                return new CompressionSettings(samplingRatio, 
allowSharedDictionary, transposeInput, seed, lossy,
                        validCompressions, sortValuesByLength, 
columnPartitioner, maxColGroupCoCode, coCodePercentage,
-                       minimumSampleSize, estimationType, costType, 
minimumCompressionRatio);
+                       minimumSampleSize, maxSampleSize, estimationType, 
costType, minimumCompressionRatio);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java 
b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java
index cf366e7..55bf3b9 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java
@@ -36,25 +36,26 @@ import org.apache.sysds.runtime.compress.utils.Util;
 
 public class CoCodeGreedy extends AColumnCoCoder {
 
-
-       private final Memorizer mem;
-
        protected CoCodeGreedy(CompressedSizeEstimator sizeEstimator, 
ICostEstimate costEstimator,
                CompressionSettings cs) {
                super(sizeEstimator, costEstimator, cs);
-               mem = new Memorizer();
        }
 
        @Override
        protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, 
int k) {
-               for(CompressedSizeInfoColGroup g : colInfos.compressionInfo)
+               colInfos.setInfo(join(colInfos.compressionInfo, _sest, _cest, 
_cs));
+               return colInfos;
+       }
+
+       protected static List<CompressedSizeInfoColGroup> 
join(List<CompressedSizeInfoColGroup> inputColumns, CompressedSizeEstimator 
sEst, ICostEstimate cEst, CompressionSettings cs) {
+               Memorizer mem = new Memorizer(cs, sEst);
+               for(CompressedSizeInfoColGroup g : inputColumns)
                        mem.put(g);
                
-               colInfos.setInfo(coCodeBruteForce(colInfos.compressionInfo));
-               return colInfos;
+               return coCodeBruteForce(inputColumns, cEst, mem);
        }
 
-       private List<CompressedSizeInfoColGroup> 
coCodeBruteForce(List<CompressedSizeInfoColGroup> inputColumns) {
+       private static List<CompressedSizeInfoColGroup> 
coCodeBruteForce(List<CompressedSizeInfoColGroup> inputColumns, ICostEstimate 
cEst, Memorizer mem) {
 
                List<ColIndexes> workset = new ArrayList<>(inputColumns.size());
 
@@ -69,8 +70,8 @@ public class CoCodeGreedy extends AColumnCoCoder {
                                for(int j = i + 1; j < workset.size(); j++) {
                                        final ColIndexes c1 = workset.get(i);
                                        final ColIndexes c2 = workset.get(j);
-                                       final double costC1 = 
_cest.getCostOfColumnGroup(mem.get(c1));
-                                       final double costC2 = 
_cest.getCostOfColumnGroup(mem.get(c2));
+                                       final double costC1 = 
cEst.getCostOfColumnGroup(mem.get(c1));
+                                       final double costC2 = 
cEst.getCostOfColumnGroup(mem.get(c2));
 
                                        mem.incst1();
                                        // pruning filter : skip dominated 
candidates
@@ -82,7 +83,7 @@ public class CoCodeGreedy extends AColumnCoCoder {
                                        // Join the two column groups.
                                        // and Memorize the new join.
                                        final CompressedSizeInfoColGroup 
c1c2Inf = mem.getOrCreate(c1, c2);
-                                       final double costC1C2 = 
_cest.getCostOfColumnGroup(c1c2Inf);
+                                       final double costC1C2 = 
cEst.getCostOfColumnGroup(c1c2Inf);
 
                                        final double newSizeChangeIfSelected = 
costC1C2 - costC1 - costC2;
 
@@ -120,11 +121,15 @@ public class CoCodeGreedy extends AColumnCoCoder {
                return ret;
        }
 
-       protected class Memorizer {
+       protected static class Memorizer {
+               private final CompressionSettings _cs;
+               private final CompressedSizeEstimator _sEst;
                private final Map<ColIndexes, CompressedSizeInfoColGroup> mem;
                private int st1 = 0, st2 = 0, st3 = 0, st4 = 0;
 
-               public Memorizer() {
+               public Memorizer(CompressionSettings cs, 
CompressedSizeEstimator sEst) {
+                       _cs = cs;
+                       _sEst = sEst;
                        mem = new HashMap<>();
                }
 
@@ -159,7 +164,7 @@ public class CoCodeGreedy extends AColumnCoCoder {
                                        g = 
CompressedSizeInfoColGroup.addConstGroup(c, left, _cs.validCompressions);
                                else {
                                        st3++;
-                                       g = _sest.estimateJoinCompressedSize(c, 
left, right);
+                                       g = _sEst.estimateJoinCompressedSize(c, 
left, right);
                                }
 
                                if(leftConst || rightConst)
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java 
b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java
new file mode 100644
index 0000000..f762c0d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.cocode;
+
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.cost.ICostEstimate;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
+
+/**
+ * This cocode strategy starts out with priority que until a threshold number 
of columnGroups is achieved, then the
+ * strategy shifts into a greedy all compare.
+ */
+public class CoCodeHybrid extends AColumnCoCoder {
+
+    protected CoCodeHybrid(CompressedSizeEstimator sizeEstimator, 
ICostEstimate costEstimator, CompressionSettings cs) {
+        super(sizeEstimator, costEstimator, cs);
+    }
+
+    @Override
+    protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, 
int k) {
+        final int startSize = colInfos.getInfo().size();
+        final int PriorityQueGoal = 40;
+        if(startSize > 200) {
+
+            colInfos.setInfo(CoCodePriorityQue.join(colInfos.getInfo(), _sest, 
_cest, PriorityQueGoal));
+
+            final int pqSize = colInfos.getInfo().size();
+            if(pqSize <= PriorityQueGoal)
+                colInfos.setInfo(CoCodeGreedy.join(colInfos.getInfo(), _sest, 
_cest, _cs));
+        }
+        else {
+            colInfos.setInfo(CoCodeGreedy.join(colInfos.getInfo(), _sest, 
_cest, _cs));
+        }
+
+        return colInfos;
+    }
+
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java 
b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java
index 27b678c..b53859e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java
@@ -48,12 +48,13 @@ public class CoCodePriorityQue extends AColumnCoCoder {
 
        @Override
        protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, 
int k) {
-               colInfos.setInfo(join(colInfos.getInfo()));
+               colInfos.setInfo(join(colInfos.getInfo(), _sest, _cest, 1));
                return colInfos;
        }
 
-       private List<CompressedSizeInfoColGroup> 
join(List<CompressedSizeInfoColGroup> currentGroups) {
-               Comparator<CompressedSizeInfoColGroup> comp = 
Comparator.comparing(x -> _cest.getCostOfColumnGroup(x));
+       protected static List<CompressedSizeInfoColGroup> 
join(List<CompressedSizeInfoColGroup> currentGroups,
+               CompressedSizeEstimator sEst, ICostEstimate cEst, int 
minNumGroups) {
+               Comparator<CompressedSizeInfoColGroup> comp = 
Comparator.comparing(x -> cEst.getCostOfColumnGroup(x));
                Queue<CompressedSizeInfoColGroup> que = new 
PriorityQueue<>(currentGroups.size(), comp);
                List<CompressedSizeInfoColGroup> ret = new ArrayList<>();
 
@@ -62,15 +63,16 @@ public class CoCodePriorityQue extends AColumnCoCoder {
                                que.add(g);
 
                CompressedSizeInfoColGroup l = null;
-               if(_cest.isCompareAll()) {
-                       double costBeforeJoin = 
_cest.getCostOfCollectionOfGroups(que);
+               if(cEst.isCompareAll()) {
+                       double costBeforeJoin = 
cEst.getCostOfCollectionOfGroups(que);
                        l = que.poll();
-                       while(que.peek() != null) {
+                       int groupNr = ret.size() + que.size();
+                       while(que.peek() != null && groupNr >= minNumGroups) {
 
                                CompressedSizeInfoColGroup r = que.poll();
-                               final CompressedSizeInfoColGroup g = 
_sest.estimateJoinCompressedSize(l, r);
+                               final CompressedSizeInfoColGroup g = 
sEst.estimateJoinCompressedSize(l, r);
                                if(g != null) {
-                                       final double costOfJoin = 
_cest.getCostOfCollectionOfGroups(que, g);
+                                       final double costOfJoin = 
cEst.getCostOfCollectionOfGroups(que, g);
                                        if(costOfJoin < costBeforeJoin) {
                                                costBeforeJoin = costOfJoin;
                                                que.add(g);
@@ -86,17 +88,19 @@ public class CoCodePriorityQue extends AColumnCoCoder {
                                }
 
                                l = que.poll();
+                               groupNr = ret.size() + que.size();
                        }
                }
                else {
                        l = que.poll();
-                       while(que.peek() != null) {
+                       int groupNr = ret.size() + que.size();
+                       while(que.peek() != null && groupNr >= minNumGroups) {
                                CompressedSizeInfoColGroup r = que.peek();
-                               if(_cest.shouldTryJoin(l, r)) {
-                                       CompressedSizeInfoColGroup g = 
_sest.estimateJoinCompressedSize(l, r);
+                               if(cEst.shouldTryJoin(l, r)) {
+                                       CompressedSizeInfoColGroup g = 
sEst.estimateJoinCompressedSize(l, r);
                                        if(g != null) {
-                                               double costOfJoin = 
_cest.getCostOfColumnGroup(g);
-                                               double costIndividual = 
_cest.getCostOfColumnGroup(l) + _cest.getCostOfColumnGroup(r);
+                                               double costOfJoin = 
cEst.getCostOfColumnGroup(g);
+                                               double costIndividual = 
cEst.getCostOfColumnGroup(l) + cEst.getCostOfColumnGroup(r);
 
                                                if(costOfJoin < costIndividual) 
{
                                                        que.poll();
@@ -112,8 +116,10 @@ public class CoCodePriorityQue extends AColumnCoCoder {
                                        ret.add(l);
 
                                l = que.poll();
+                               groupNr = ret.size() + que.size();
                        }
                }
+
                if(l != null)
                        ret.add(l);
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java 
b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
index c467f70..eaf9cb4 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
@@ -63,10 +63,7 @@ public class CoCoderFactory {
                                // TODO make decision better depending on how 
much time is allocated for the compression
                                // for instance if the compressed object is 
used for a million instructions, it might be good to
                                // search for a really good compression even if 
it take longer.
-                               if(est.getNumColumns() > 200)
-                                       return new CoCodePriorityQue(est, 
costEstimator, cs);
-                               else
-                                       return new CoCodeGreedy(est, 
costEstimator, cs);
+                               return new CoCodeHybrid(est, costEstimator, cs);
                        case GREEDY:
                                return new CoCodeGreedy(est, costEstimator, cs);
                        case BIN_PACKING:
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index 5c6d365..fb93ed6 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -558,11 +558,13 @@ public class ColGroupSDC extends ColGroupValue {
        public ColGroupSDCZeros extractCommon(double[] constV) {
                double[] commonV = _dict.getTuple(getNumValues() - 1, 
_colIndexes.length);
 
-               for(int i = 0; i < _colIndexes.length; i++) {
+               if(commonV == null) // The common tuple was all zero. Therefore 
this column group should never have been SDC.
+                       return new ColGroupSDCZeros(_colIndexes, _numRows, 
_dict, _indexes, _data, getCounts());
+
+               for(int i = 0; i < _colIndexes.length; i++)
                        constV[_colIndexes[i]] += commonV[i];
-               }
+
                ADictionary subtractedDict = _dict.subtractTuple(commonV);
                return new ColGroupSDCZeros(_colIndexes, _numRows, 
subtractedDict, _indexes, _data, getCounts());
        }
-
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
index de0b4a8..b2b4283 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
@@ -117,8 +117,8 @@ public class ColGroupSDCSingle extends ColGroupValue {
                        for(int j = 0; j < nCol; j++)
                                c[off + _colIndexes[j]] += 
values[offsetToDefault + j];
                }
-               
-               _indexes.cacheIterator(it, ru );
+
+               _indexes.cacheIterator(it, ru);
        }
 
        @Override
@@ -473,11 +473,14 @@ public class ColGroupSDCSingle extends ColGroupValue {
        public ColGroupSDCSingleZeros extractCommon(double[] constV) {
                double[] commonV = _dict.getTuple(getNumValues() - 1, 
_colIndexes.length);
 
+               if(commonV == null) // The common tuple was all zero. Therefore 
this column group should never have been SDC.
+                       return new ColGroupSDCSingleZeros(_colIndexes, 
_numRows, _dict, _indexes, getCachedCounts());
+
                for(int i = 0; i < _colIndexes.length; i++)
                        constV[_colIndexes[i]] += commonV[i];
 
                ADictionary subtractedDict = _dict.subtractTuple(commonV);
-               return new ColGroupSDCSingleZeros(_colIndexes, _numRows, 
subtractedDict, _indexes, getCounts());
+               return new ColGroupSDCSingleZeros(_colIndexes, _numRows, 
subtractedDict, _indexes, getCachedCounts());
        }
 
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index 15c74b0..df75268 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -35,7 +35,7 @@ import 
org.apache.sysds.runtime.matrix.operators.ScalarOperator;
 public abstract class ADictionary implements Serializable {
 
        private static final long serialVersionUID = 9118692576356558592L;
-       
+
        protected static final Log LOG = 
LogFactory.getLog(ADictionary.class.getName());
 
        /**
@@ -342,6 +342,8 @@ public abstract class ADictionary implements Serializable {
        /**
         * Get the values contained in a specific tuple of the dictionary.
         * 
+        * If the entire row is zero return null.
+        * 
         * @param index The index where the values are located
         * @param nCol  The number of columns contained in this dictionary
         * @return a materialized double array containing the tuple.
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
index 5b9d834..0b65ed1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
@@ -43,7 +43,7 @@ import org.apache.sysds.utils.MemoryEstimates;
 public class Dictionary extends ADictionary {
 
        private static final long serialVersionUID = -6517136537249507753L;
-       
+
        private final double[] _values;
 
        public Dictionary(double[] values) {
@@ -329,14 +329,13 @@ public class Dictionary extends ADictionary {
                if(colIndexes == 1)
                        sb.append(Arrays.toString(_values));
                else {
-                       sb.append("[\n");
+                       sb.append("[\n\t");
                        for(int i = 0; i < _values.length - 1; i++) {
                                sb.append(_values[i]);
-                               sb.append((i) % (colIndexes) == colIndexes - 1 
? "\nt" + i + ": " : ", ");
+                               sb.append((i) % (colIndexes) == colIndexes - 1 
? "\n\t" : ", ");
                        }
                        sb.append(_values[_values.length - 1]);
-
-                       sb.append("\n]");
+                       sb.append("]");
                }
                return sb.toString();
        }
@@ -476,7 +475,7 @@ public class Dictionary extends ADictionary {
        }
 
        @Override
-       public Dictionary preaggValuesFromDense(int numVals, int[] colIndexes, 
int[] aggregateColumns, double[] b, 
+       public Dictionary preaggValuesFromDense(int numVals, int[] colIndexes, 
int[] aggregateColumns, double[] b,
                int cut) {
                double[] ret = new double[numVals * aggregateColumns.length];
                for(int k = 0, off = 0;
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
 
b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
index 1653f51..07cf58b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
@@ -23,6 +23,7 @@ import java.util.Collection;
 
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
 public class ComputationCostEstimator implements ICostEstimate {
 
@@ -41,7 +42,6 @@ public class ComputationCostEstimator implements 
ICostEstimate {
        private final int _leftMultiplications;
        private final int _rightMultiplications;
        private final int _compressedMultiplication;
-       // private final int _rowBasedOps;
        private final int _dictionaryOps;
 
        private final boolean _isDensifying;
@@ -100,16 +100,21 @@ public class ComputationCostEstimator implements 
ICostEstimate {
                cost += _scans * scanCost(g);
                cost += _decompressions * decompressionCost(g);
                cost += _overlappingDecompressions * 
overlappingDecompressionCost(g);
-               // 16 is assuming that the left side is 16 rows.
-               double lmc = leftMultCost(g) * 16;
-               cost += _leftMultiplications * lmc;
-               // 16 is assuming that the right side is 16 rows.
-               double rmc = rightMultCost(g) * 16;
-               cost += _rightMultiplications * rmc;
-
-               // cost += _compressedMultiplication * (lmc + rmc);
-               cost += _compressedMultiplication * _compressedMultCost(g);
+               // 16 is assuming that the left / right side is 16 rows/cols.
+               final int rowsCols = 16;
+               cost += _leftMultiplications *  leftMultCost(g) * rowsCols;
+               cost += _rightMultiplications * rightMultCost(g) * rowsCols;
                cost += _dictionaryOps * dictionaryOpsCost(g);
+               
+               double size = g.getMinSize();
+               final double compressionRatio =  size / 
MatrixBlock.estimateSizeDenseInMemory(_nRows, _nCols) / g.getColumns().length;
+
+               cost *=  0.001 + compressionRatio;
+
+               cost += _compressedMultiplication * _compressedMultCost(g) * 
rowsCols;
+
+               // double uncompressedSize = 
g.getCompressionSize(CompressionType.UNCOMPRESSED);
+
                return cost;
        }
 
@@ -118,13 +123,14 @@ public class ComputationCostEstimator implements 
ICostEstimate {
        }
 
        private double leftMultCost(CompressedSizeInfoColGroup g) {
-               final int nCols = g.getColumns().length;
-               final double preAggregateCost = _nRows;
+               final int nColsInGroup = g.getColumns().length;
+               final double mcf = g.getMostCommonFraction();
+               final double preAggregateCost = mcf > 0.6 ? _nRows * (1 - 0.4 * 
mcf) : _nRows;
 
                final double numberTuples = g.getNumVals();
                final double tupleSparsity = g.getTupleSparsity();
-               final double postScalingCost = (nCols > 1 && tupleSparsity > 
0.4) ? numberTuples * nCols * tupleSparsity *
-                       1.4 : numberTuples * nCols;
+               final double postScalingCost = (nColsInGroup > 1 && 
tupleSparsity > 0.4) ? numberTuples * nColsInGroup * tupleSparsity *
+                       1.4 : numberTuples * nColsInGroup;
                if(numberTuples < 64000)
                        return preAggregateCost + postScalingCost;
                else
@@ -134,10 +140,11 @@ public class ComputationCostEstimator implements 
ICostEstimate {
 
        private double _compressedMultCost(CompressedSizeInfoColGroup g) {
                final int nColsInGroup = g.getColumns().length;
-               final double mcf = g.getMostCommonFraction();
-               final double preAggregateCost = mcf > 0.6 ? _nRows * (1 - 0.7 * 
mcf) : _nRows;
+               // final double mcf = g.getMostCommonFraction();
+               // final double preAggregateCost = (mcf > 0.6 ? _nRows * (1 - 
0.6 * mcf) : _nRows) * 4;
+               final double preAggregateCost = _nRows;
 
-               final double numberTuples = (float) g.getNumVals();
+               final double numberTuples = g.getNumVals();
                final double tupleSparsity = g.getTupleSparsity();
                final double postScalingCost = (nColsInGroup > 1 && 
tupleSparsity > 0.4) ? numberTuples * nColsInGroup * tupleSparsity *
                        1.4 : numberTuples * nColsInGroup;
@@ -163,7 +170,10 @@ public class ComputationCostEstimator implements 
ICostEstimate {
        }
 
        private double overlappingDecompressionCost(CompressedSizeInfoColGroup 
g) {
-               return _nRows * 16 * (g.getNumVals() / 64000 + 1);
+               final double mcf = g.getMostCommonFraction();
+               final double rowsCost = mcf > 0.6 ? _nRows * (1 - 0.6 * mcf) : 
_nRows;
+               // The factor 16 marks decompression as expensive.
+               return rowsCost * 16 * ((float)g.getNumVals() / 64000 + 1);
        }
 
        private static double dictionaryOpsCost(CompressedSizeInfoColGroup g) {
@@ -259,15 +269,19 @@ public class ComputationCostEstimator implements 
ICostEstimate {
        public String toString() {
                StringBuilder sb = new StringBuilder();
                sb.append(this.getClass().getSimpleName());
-               sb.append("\n");
-               sb.append(_nRows + "  ");
-               sb.append(_scans + " ");
-               sb.append(_decompressions + " ");
-               sb.append(_overlappingDecompressions + " ");
-               sb.append(_leftMultiplications + " ");
-               sb.append(_rightMultiplications + " ");
-               sb.append(_compressedMultiplication + " ");
-               sb.append(_dictionaryOps + " ");
+               sb.append("dims(");
+               sb.append(_nRows + ",");
+               sb.append(_nCols + ") ");
+               sb.append("CostVector:[");
+               sb.append(_scans + ",");
+               sb.append(_decompressions + ",");
+               sb.append(_overlappingDecompressions + ",");
+               sb.append(_leftMultiplications + ",");
+               sb.append(_rightMultiplications + ",");
+               sb.append(_compressedMultiplication + ",");
+               sb.append(_dictionaryOps + "]");
+               sb.append(" Densifying:");
+               sb.append(_isDensifying);
                return sb.toString();
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java
 
b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java
index a4b1c99..a9a8e44 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java
@@ -27,7 +27,6 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
 public class CompressedSizeEstimatorFactory {
        protected static final Log LOG = 
LogFactory.getLog(CompressedSizeEstimatorFactory.class.getName());
-       private static final int maxSampleSize = 1000000;
 
        public static CompressedSizeEstimator getSizeEstimator(MatrixBlock 
data, CompressionSettings cs, int k) {
 
@@ -36,7 +35,7 @@ public class CompressedSizeEstimatorFactory {
                final int nnzRows = (int) Math.ceil(data.getNonZeros() / nCols);
 
                final double sampleRatio = cs.samplingRatio;
-               final int sampleSize = Math.min(getSampleSize(sampleRatio, 
nRows, cs.minimumSampleSize), maxSampleSize);
+               final int sampleSize = getSampleSize(sampleRatio, nRows, nCols, 
cs.minimumSampleSize, cs.maxSampleSize);
 
                if(nCols > 1000) {
                        return tryToMakeSampleEstimator(data, cs, sampleRatio, 
sampleSize / 10, nRows, nnzRows, k);
@@ -79,7 +78,27 @@ public class CompressedSizeEstimatorFactory {
                return cs.samplingRatio >= 1.0 || nRows < cs.minimumSampleSize 
|| sampleSize >= nnzRows;
        }
 
-       private static int getSampleSize(double sampleRatio, int nRows, int 
minimumSampleSize) {
-               return Math.max((int) Math.ceil(nRows * sampleRatio), 
minimumSampleSize);
+       /**
+        * This function returns the sample size to use.
+        * 
+        * The sample size is bounded by the maximum and the minimum sample 
size; between those bounds a linear relation is used
+        * with the sample ratio.
+        * 
+        * Also influencing the sample size is the number of columns. If the 
number of columns is large the sample size is
+        * scaled down. This gives worse estimations of distinct items, but it 
makes sure that the compression time is more
+        * consistent.
+        * 
+        * @param sampleRatio       The sample ratio
+        * @param nRows             The number of rows
+        * @param nCols             The number of columns
+        * @param minSampleSize     the minimum sample size
+        * @param maxSampleSize     the maximum sample size
+        * @return The sample size to use.
+        */
+       private static int getSampleSize(double sampleRatio, int nRows, int 
nCols, int minSampleSize, int maxSampleSize) {
+               int sampleSize = (int) Math.ceil(nRows * sampleRatio / 
Math.max(1, (double)nCols / 150));
+               if(sampleSize < 20000)
+                       sampleSize *= 2;
+               return Math.min(Math.max(sampleSize, minSampleSize), 
maxSampleSize);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
index e330ab8..b9b64fa 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
@@ -105,7 +105,8 @@ public class CLALibBinaryCellOp {
                                result = 
CompressedMatrixBlockFactory.createConstant(m1.getNumRows(), 
m1.getNumColumns(), 0);
                        else if(fn instanceof Minus1Multiply)
                                result = 
CompressedMatrixBlockFactory.createConstant(m1.getNumRows(), 
m1.getNumColumns(), 1);
-                       else if(fn instanceof Minus || fn instanceof Plus || fn 
instanceof MinusMultiply || fn instanceof PlusMultiply){
+                       else if(fn instanceof Minus || fn instanceof Plus || fn 
instanceof MinusMultiply ||
+                               fn instanceof PlusMultiply) {
                                CompressedMatrixBlock ret = new 
CompressedMatrixBlock();
                                ret.copy(m1);
                                return ret;
@@ -132,7 +133,11 @@ public class CLALibBinaryCellOp {
                // TODO optimize to allow for sparse outputs.
                final int outCells = outRows * outCols;
                if(atype == BinaryAccessType.MATRIX_COL_VECTOR) {
-                       result.reset(outRows, Math.max(outCols, 
that.getNumColumns()), outCells);
+                       if(result != null)
+                               result.reset(outRows, Math.max(outCols, 
that.getNumColumns()), outCells);
+                       else
+                               result = new MatrixBlock(outRows, 
Math.max(outCols, that.getNumColumns()), outCells);
+
                        MatrixBlock d_compressed = m1.getCachedDecompressed();
                        if(d_compressed != null) {
                                if(left)
@@ -146,12 +151,15 @@ public class CLALibBinaryCellOp {
 
                }
                else if(atype == BinaryAccessType.MATRIX_MATRIX) {
-                       result.reset(outRows, outCols, outCells);
+                       if(result != null)
+                               result.reset(outRows, outCols, outCells);
+                       else
+                               result = new MatrixBlock(outRows, outCols, 
outCells);
 
                        MatrixBlock d_compressed = m1.getCachedDecompressed();
                        if(d_compressed == null)
                                d_compressed = m1.getUncompressed("MatrixMatrix 
" + op);
-                       
+
                        if(left)
                                LibMatrixBincell.bincellOp(that, d_compressed, 
result, op);
                        else
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
new file mode 100644
index 0000000..ff2d211
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
+/**
+ * Library to decompress a list of column groups into a matrix.
+ */
+public class CLALibDecompress {
+       public static MatrixBlock decompress(MatrixBlock ret, List<AColGroup> 
groups, long nonZeros, boolean overlapping) {
+
+               final int rlen = ret.getNumRows();
+               final int clen = ret.getNumColumns();
+               final int block = (int) Math.ceil((double) 
(CompressionSettings.BITMAP_BLOCK_SZ) / clen);
+               final int blklen = block > 1000 ? block + 1000 - block % 1000 : 
Math.max(64, block);
+               final boolean containsSDC = CLALibUtils.containsSDC(groups);
+               double[] constV = containsSDC ? new double[ret.getNumColumns()] 
: null;
+               final List<AColGroup> filteredGroups = containsSDC ? 
CLALibUtils.filterSDCGroups(groups, constV) : groups;
+
+               sortGroups(filteredGroups, overlapping);
+               // check if we are using filtered groups, and if we are not 
force constV to null
+               if(groups == filteredGroups)
+                       constV = null;
+
+               final double eps = getEps(constV);
+               for(int i = 0; i < rlen; i += blklen) {
+                       final int rl = i;
+                       final int ru = Math.min(i + blklen, rlen);
+                       for(AColGroup grp : filteredGroups)
+                               grp.decompressToBlockUnSafe(ret, rl, ru);
+                       if(constV != null && !ret.isInSparseFormat())
+                               addVector(ret, constV, eps, rl, ru);
+               }
+
+               ret.setNonZeros(nonZeros == -1 || overlapping ? 
ret.recomputeNonZeros() : nonZeros);
+
+               return ret;
+       }
+
+       public static MatrixBlock decompress(MatrixBlock ret, List<AColGroup> 
groups, boolean overlapping, int k) {
+
+               try {
+                       final ExecutorService pool = CommonThreadPool.get(k);
+                       final int rlen = ret.getNumRows();
+                       final int block = (int) Math.ceil((double) 
(CompressionSettings.BITMAP_BLOCK_SZ) / ret.getNumColumns());
+                       final int blklen = block > 1000 ? block + 1000 - block 
% 1000 : Math.max(64, block);
+
+                       final boolean containsSDC = 
CLALibUtils.containsSDC(groups);
+                       double[] constV = containsSDC ? new 
double[ret.getNumColumns()] : null;
+                       final List<AColGroup> filteredGroups = containsSDC ? 
CLALibUtils.filterSDCGroups(groups, constV) : groups;
+                       sortGroups(filteredGroups, overlapping);
+
+                       // check if we are using filtered groups, and if we are 
not force constV to null
+                       if(groups == filteredGroups)
+                               constV = null;
+
+                       final double eps = getEps(constV);
+                       final ArrayList<DecompressTask> tasks = new 
ArrayList<>();
+                       for(int i = 0; i * blklen < rlen; i++)
+                               tasks.add(new DecompressTask(filteredGroups, 
ret, eps, i * blklen, Math.min((i + 1) * blklen, rlen),
+                                       overlapping, constV));
+                       List<Future<Long>> rtasks = pool.invokeAll(tasks);
+                       pool.shutdown();
+
+                       long nnz = 0;
+                       for(Future<Long> rt : rtasks)
+                               nnz += rt.get();
+                       ret.setNonZeros(nnz);
+               }
+               catch(InterruptedException | ExecutionException ex) {
+                       throw new DMLCompressionException("Parallel 
decompression failed", ex);
+               }
+
+               return ret;
+       }
+
+       private static void sortGroups(List<AColGroup> groups, boolean 
overlapping) {
+               if(overlapping) {
+                       // add a bit of stability in decompression
+                       Comparator<AColGroup> comp = Comparator.comparing(x -> 
effect(x));
+                       groups.sort(comp);
+               }
+       }
+
+       /**
+        * Calculate an effect value for a column group. This is used to sort 
the groups before decompression to decompress
+        * the columns that have the smallest effect first.
+        * 
+        * @param x A Group
+        * @return An effect value as a double.
+        */
+       private static double effect(AColGroup x) {
+               return -Math.max(x.getMax(), Math.abs(x.getMin()));
+       }
+
+       /**
+        * Get a small epsilon from the constant group.
+        * 
+        * @param constV the constant vector.
+        * @return epsilon
+        */
+       private static double getEps(double[] constV) {
+               if(constV == null)
+                       return 0;
+               else {
+                       double max = -Double.MAX_VALUE;
+                       double min = Double.MAX_VALUE;
+                       for(double v : constV){
+                               if(v > max)
+                                       max = v;
+                               if(v < min)
+                                       min = v;
+                       }
+                       final double eps = (max-min) * 1e-13;
+                       return eps;
+               }
+       }
+
+       private static class DecompressTask implements Callable<Long> {
+               private final List<AColGroup> _colGroups;
+               private final MatrixBlock _ret;
+               private final double _eps;
+               private final int _rl;
+               private final int _ru;
+               private final double[] _constV;
+               private final boolean _overlapping;
+
+               protected DecompressTask(List<AColGroup> colGroups, MatrixBlock 
ret, double eps, int rl, int ru,
+                       boolean overlapping, double[] constV) {
+                       _colGroups = colGroups;
+                       _ret = ret;
+                       _eps = eps;
+                       _rl = rl;
+                       _ru = ru;
+                       _overlapping = overlapping;
+                       _constV = constV;
+               }
+
+               @Override
+               public Long call() {
+                       // decompress row partition
+                       for(AColGroup grp : _colGroups)
+                               grp.decompressToBlockUnSafe(_ret, _rl, _ru);
+
+                       if(_constV != null)
+                               addVector(_ret, _constV, _eps, _rl, _ru);
+
+                       return _overlapping ? 0 : _ret.recomputeNonZeros(_rl, 
_ru - 1);
+               }
+       }
+
+       /**
+        * Add the rowV vector to each row in ret.
+        * 
+        * @param ret  matrix to add the vector to
+        * @param rowV The row vector to add
+        * @param eps  an epsilon used to round an output value to zero if 
its absolute value is at most epsilon away from
+        *             zero.
+        * @param rl   The row to start at
+        * @param ru   The row to end at
+        */
+       private static void addVector(final MatrixBlock ret, final double[] 
rowV, final double eps, final int rl,
+               final int ru) {
+               final int nCols = ret.getNumColumns();
+               final DenseBlock db = ret.getDenseBlock();
+               if(eps == 0) {
+                       for(int row = rl; row < ru; row++) {
+                               final double[] _retV = db.values(row);
+                               final int off = db.pos(row);
+                               for(int col = 0; col < nCols; col++)
+                                       _retV[off + col] += rowV[col];
+                       }
+               }
+               else {
+                       for(int row = rl; row < ru; row++) {
+                               final double[] _retV = db.values(row);
+                               final int off = db.pos(row);
+                               for(int col = 0; col < nCols; col++) {
+                                       final int out = off + col;
+                                       _retV[out] += rowV[col];
+                                       if(Math.abs(_retV[out]) <= eps)
+                                               _retV[out] = 0;
+                               }
+                       }
+               }
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 8b78e58..dc2df68 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -110,11 +110,11 @@ public class CLALibLeftMultBy {
                        multAllColGroups(groups, groups, result);
                }
                else {
-                       final boolean containsSDC = containsSDC(groups);
-                       final int numColumns = cmb.getNumColumns();
+                       final boolean containsSDC = 
CLALibUtils.containsSDC(groups);
                        final double[] constV = containsSDC ? new 
double[cmb.getNumColumns()] : null;
-                       final List<AColGroup> filteredGroups = 
filterSDCGroups(groups, constV);
+                       final List<AColGroup> filteredGroups = 
CLALibUtils.filterSDCGroups(groups, constV);
                        final double[] colSums = containsSDC ? new 
double[cmb.getNumColumns()] : null;
+                       final int numColumns = cmb.getNumColumns();
 
                        if(containsSDC)
                                for(int i = 0; i < groups.size(); i++) {
@@ -298,12 +298,13 @@ public class CLALibLeftMultBy {
                }
 
                final int numColumnsOut = ret.getNumColumns();
-               final boolean containsSDC = containsSDC(colGroups);
+               final boolean containsSDC = CLALibUtils.containsSDC(colGroups);
 
                // a constant colgroup summing the default values.
-               final double[] constV = containsSDC ? new double[numColumnsOut] 
: null;
-               final List<AColGroup> filteredGroups = 
filterSDCGroups(colGroups, constV);
-
+               double[] constV = containsSDC ? new double[numColumnsOut] : 
null;
+               final List<AColGroup> filteredGroups = 
CLALibUtils.filterSDCGroups(colGroups, constV);
+               if(colGroups == filteredGroups)
+                       constV = null;
                final double[] rowSums = containsSDC ? new 
double[that.getNumRows()] : null;
 
                if(k == 1) {
@@ -633,33 +634,4 @@ public class CLALibLeftMultBy {
                Collections.sort(ColGroupValues, 
Comparator.comparing(AColGroup::getNumValues).reversed());
                return ColGroupValues;
        }
-
-       private static boolean containsSDC(List<AColGroup> groups) {
-               boolean containsSDC = false;
-
-               for(AColGroup g : groups) {
-                       if(g instanceof ColGroupSDC || g instanceof 
ColGroupSDCSingle) {
-                               containsSDC = true;
-                               break;
-                       }
-               }
-               return containsSDC;
-       }
-
-       private static List<AColGroup> filterSDCGroups(List<AColGroup> groups, 
double[] constV) {
-               if(constV != null) {
-                       final List<AColGroup> filteredGroups = new 
ArrayList<>();
-                       for(AColGroup g : groups) {
-                               if(g instanceof ColGroupSDC)
-                                       filteredGroups.add(((ColGroupSDC) 
g).extractCommon(constV));
-                               else if(g instanceof ColGroupSDCSingle)
-                                       filteredGroups.add(((ColGroupSDCSingle) 
g).extractCommon(constV));
-                               else
-                                       filteredGroups.add(g);
-                       }
-                       return filteredGroups;
-               }
-               else
-                       return groups;
-       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
index 1822685..ace2843 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
@@ -68,8 +68,7 @@ public class CLALibScalar {
                if(m1.isOverlapping() && !(sop.fn instanceof Multiply || sop.fn 
instanceof Divide)) {
                        AColGroup constOverlap = constOverlap(m1, sop);
                        List<AColGroup> newColGroups = (sop instanceof 
LeftScalarOperator &&
-                               sop.fn instanceof Minus) ? 
processOverlappingSubtractionLeft(m1,
-                                       sop,
+                               sop.fn instanceof Minus) ? 
processOverlappingSubtractionLeft(m1, sop,
                                        ret) : processOverlappingAddition(m1, 
sop, ret);
                        newColGroups.add(constOverlap);
                        ret.allocateColGroupList(newColGroups);
@@ -93,8 +92,8 @@ public class CLALibScalar {
                }
 
                ret.recomputeNonZeros();
-               return ret;
 
+               return ret;
        }
 
        private static CompressedMatrixBlock setupRet(CompressedMatrixBlock m1, 
MatrixValue result) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
new file mode 100644
index 0000000..c701b96
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle;
+
+public class CLALibUtils {
+
+       /**
+        * Helper method to determine if the column groups contain SDC groups.
+        * 
+        * Note that it only returns true, if there is more than one SDC Group.
+        * 
+        * @param groups The ColumnGroups to analyze
+        * @return A boolean saying if there are >= 2 SDC Groups.
+        */
+       protected static boolean containsSDC(List<AColGroup> groups) {
+               int count = 0;
+               for(AColGroup g : groups) {
+                       if(g instanceof ColGroupSDC || g instanceof 
ColGroupSDCSingle) {
+                               count++;
+                               if(count > 1)
+                                       break;
+                       }
+               }
+               return count > 1;
+       }
+
+       /**
+        * Helper method to filter out SDC Groups, to add their common value to 
the ConstV. This allows exploitation of the
+        * common values in the SDC Groups.
+        * 
+        * @param groups The Column Groups
+        * @param constV The Constant vector to add common values to.
+        * @return The Filtered list of Column groups containing no SDC Groups 
but only SDCZero groups.
+        */
+       protected static List<AColGroup> filterSDCGroups(List<AColGroup> 
groups, double[] constV) {
+               if(constV != null) {
+                       final List<AColGroup> filteredGroups = new 
ArrayList<>();
+                       for(AColGroup g : groups) {
+                               if(g instanceof ColGroupSDC)
+                                       filteredGroups.add(((ColGroupSDC) 
g).extractCommon(constV));
+                               else if(g instanceof ColGroupSDCSingle)
+                                       filteredGroups.add(((ColGroupSDCSingle) 
g).extractCommon(constV));
+                               else
+                                       filteredGroups.add(g);
+                       }
+                       for(double v : constV)
+                               if(!Double.isFinite(v))
+                                       return groups;
+                       
+                       return filteredGroups;
+               }
+               else
+                       return groups;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
 
b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
index e086974..06431db 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
@@ -425,40 +425,48 @@ public class WorkloadAnalyzer {
                                        setDecompressionOnAllInputs(hop, 
parent);
                                        return;
                                }
-                               // shortcut instead of comparing to 
MatrixScalar or RowVector.
-                               else if(hop.getInput(1).getDim1() == 1 || 
hop.getInput(1).isScalar() || hop.getInput(0).isScalar()) {
-
+                               else {
                                        ArrayList<Hop> in = hop.getInput();
                                        final boolean ol0 = 
isOverlapping(in.get(0));
                                        final boolean ol1 = 
isOverlapping(in.get(1));
                                        final boolean ol = ol0 || ol1;
-                                       if(ol && HopRewriteUtils.isBinary(hop, 
OpOp2.PLUS, OpOp2.MULT, OpOp2.DIV, OpOp2.MINUS)) {
-                                               overlapping.add(hop.getHopID());
-                                               o = new OpNormal(hop, true);
-                                               o.setOverlapping();
+
+                                       // shortcut instead of comparing to 
MatrixScalar or RowVector.
+                                       if(in.get(1).getDim1() == 1 || 
in.get(1).isScalar() || in.get(0).isScalar()) {
+
+                                               if(ol && 
HopRewriteUtils.isBinary(hop, OpOp2.PLUS, OpOp2.MULT, OpOp2.DIV, OpOp2.MINUS)) {
+                                                       
overlapping.add(hop.getHopID());
+                                                       o = new OpNormal(hop, 
true);
+                                                       o.setOverlapping();
+                                               }
+                                               else if(ol) {
+                                                       
treeLookup.get(in.get(0).getHopID()).setDecompressing();
+                                                       return;
+                                               }
+                                               else {
+                                                       o = new OpNormal(hop, 
true);
+                                               }
+                                               
if(!HopRewriteUtils.isBinarySparseSafe(hop))
+                                                       o.setDensifying();
+
                                        }
-                                       else if(ol) {
-                                               
treeLookup.get(in.get(0).getHopID()).setDecompressing();
+                                       else 
if(HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) ||
+                                               
HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ||
+                                               
HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) {
+                                               
setDecompressionOnAllInputs(hop, parent);
+                                               return;
+                                       }
+                                       else if(ol0 || ol1){
+                                               
setDecompressionOnAllInputs(hop, parent);
                                                return;
                                        }
                                        else {
-                                               o = new OpNormal(hop, true);
+                                               String ex = "Setting 
decompressed because input Binary Op is unknown, please add the case to 
WorkloadAnalyzer:\n"
+                                                       + Explain.explain(hop);
+                                               LOG.warn(ex);
+                                               
setDecompressionOnAllInputs(hop, parent);
+                                               return;
                                        }
-                                       
if(!HopRewriteUtils.isBinarySparseSafe(hop))
-                                               o.setDensifying();
-
-                               }
-                               else 
if(HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) ||
-                                       
HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ||
-                                       
HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) {
-                                       setDecompressionOnAllInputs(hop, 
parent);
-                                       return;
-                               }
-                               else {
-                                       String ex = "Setting decompressed 
because input Binary Op is unknown, please add the case to WorkloadAnalyzer:\n"
-                                               + Explain.explain(hop);
-                                       LOG.warn(ex);
-                                       setDecompressionOnAllInputs(hop, 
parent);
                                }
 
                        }
diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
 
b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
index f1682a7..ba50a0e 100644
--- 
a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
@@ -117,7 +117,7 @@ public class WorkloadTest {
                args.put("$3", "0");
 
                // no recompile
-               tests.add(new Object[] {0, 1, 1, 1, 1, 1, 6, 0, true, false, 
"functions/lmDS.dml", args});
+               tests.add(new Object[] {0, 1, 1, 1, 1, 1, 5, 0, true, false, 
"functions/lmDS.dml", args});
                // with recompile
                tests.add(new Object[] {0, 0, 0, 1, 0, 1, 0, 0, true, true, 
"functions/lmDS.dml", args});
                tests.add(new Object[] {0, 0, 0, 1, 10, 10, 1, 0, true, true, 
"functions/lmCG.dml", args});
diff --git 
a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java
 
b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java
index dc5b17f..c5710bd 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java
@@ -188,7 +188,7 @@ public class CompressForce extends CompressBase {
                // be aware that with multiple blocks it is likely that the 
small blocks
                // initially compress, but is to large for overlapping state 
therefor will decompress.
                // In this test it decompress the second small block but keeps 
the first in overlapping state.
-               runTest(1110, 30, 1, 1, ExecType.SPARK, "mmr_sum_plus_2");
+               compressTest(1110, 10, 1.0, ExecType.SPARK, 1, 6, 1, 1, 1, 
"mmr_sum_plus_2");
        }
 
        @Test

Reply via email to