This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit e7f4ef1b066cd76cb8a7166f8aa37a05fba45e24
Author: baunsgaard <baunsga...@tugraz.at>
AuthorDate: Mon Dec 27 17:54:54 2021 +0100

    [SYSTEMDS-3263] CLA MMChain specialization
    
    This commit optimizes the MMChain operation to perform better in compressed
    space.
    
    The modifications in this commit improve performance for:
    census_enc 1000 reps: 154 sec -> 19 sec
    InfiMNIST_1m 100 reps: 100 sec -> 23.5 sec
    
    The improvements come from dedicated kernels for decompression in DDC,
    Const vector addition, and SDC offset iteration optimizations.
    
    Right multiplication now has a dedicated fused kernel for multiplication
    and decompression that parallelizes the allocation of the output and
    computes the NNZ count as part of decompressing the overlapping result.
    Of the above time for census_enc, decompression was 90 sec and is now
    reduced to 9 sec.
    
    Closes #1491
---
 .../runtime/compress/CompressedMatrixBlock.java    |  38 +----
 .../runtime/compress/colgroup/ColGroupDDC.java     |  17 ++-
 .../compress/colgroup/ColGroupSDCZeros.java        | 111 +++++++++++---
 .../compress/colgroup/offset/AIterator.java        |  28 ++--
 .../runtime/compress/colgroup/offset/AOffset.java  |  38 ++---
 .../compress/colgroup/offset/OffsetByte.java       | 168 ++++++++++++++++-----
 .../compress/colgroup/offset/OffsetChar.java       | 122 +++++++--------
 .../runtime/compress/lib/CLALibDecompress.java     | 125 +++++++++++----
 .../runtime/compress/lib/CLALibLeftMultBy.java     |   6 +-
 .../sysds/runtime/compress/lib/CLALibMMChain.java  | 112 ++++++++++++++
 .../runtime/compress/lib/CLALibRightMultBy.java    | 111 ++++++++++----
 src/test/java/org/apache/sysds/test/TestUtils.java |   2 +-
 .../compress/offset/OffsetTestPreAggregate.java    |  72 ++++++---
 13 files changed, 670 insertions(+), 280 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index 38af860..85cc23b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -36,8 +36,6 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.commons.math3.random.Well1024a;
 import org.apache.sysds.common.Types.CorrectionLocationType;
-import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.lops.MMTSJ.MMTSJType;
 import org.apache.sysds.lops.MapMultChain.ChainType;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -52,6 +50,7 @@ import 
org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
 import org.apache.sysds.runtime.compress.lib.CLALibCompAgg;
 import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
 import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
+import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
 import org.apache.sysds.runtime.compress.lib.CLALibReExpand;
 import org.apache.sysds.runtime.compress.lib.CLALibRightMultBy;
 import org.apache.sysds.runtime.compress.lib.CLALibScalar;
@@ -67,7 +66,6 @@ import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseRow;
 import org.apache.sysds.runtime.functionobjects.MinusMultiply;
-import org.apache.sysds.runtime.functionobjects.Multiply;
 import org.apache.sysds.runtime.functionobjects.PlusMultiply;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import 
org.apache.sysds.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
@@ -77,7 +75,6 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 import org.apache.sysds.runtime.matrix.data.CTableMap;
 import org.apache.sysds.runtime.matrix.data.IJV;
-import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
 import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
 import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.LibMatrixTercell;
@@ -470,38 +467,7 @@ public class CompressedMatrixBlock extends MatrixBlock {
                        _colGroups.get(0).getCompType() == 
CompressionType.UNCOMPRESSED)
                        return ((ColGroupUncompressed) 
_colGroups.get(0)).getData().chainMatrixMultOperations(v, w, out, ctype, k);
 
-               // prepare result
-               if(out != null)
-                       out.reset(clen, 1, false);
-               else
-                       out = new MatrixBlock(clen, 1, false);
-
-               // empty block handling
-               if(isEmpty())
-                       return out;
-
-               BinaryOperator bop = new 
BinaryOperator(Multiply.getMultiplyFnObject(), k);
-               boolean allowOverlap = 
ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.COMPRESSED_OVERLAPPING)
 &&
-                       v.getNumColumns() > 1;
-               MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(this, v, 
null, k, allowOverlap);
-
-               if(ctype == ChainType.XtwXv) {
-                       if(tmp instanceof CompressedMatrixBlock)
-                               tmp = 
CLALibBinaryCellOp.binaryOperationsRight(bop, (CompressedMatrixBlock) tmp, w, 
null);
-                       else
-                               LibMatrixBincell.bincellOpInPlace(tmp, w, bop);
-               }
-
-               if(tmp instanceof CompressedMatrixBlock)
-                       CLALibLeftMultBy.leftMultByMatrixTransposed(this, 
(CompressedMatrixBlock) tmp, out, k);
-               else
-                       CLALibLeftMultBy.leftMultByMatrixTransposed(this, tmp, 
out, k);
-
-               if(out.getNumColumns() != 1)
-                       out = LibMatrixReorg.transposeInPlace(out, k);
-
-               out.recomputeNonZeros();
-               return out;
+               return CLALibMMChain.mmChain(this, v, w, out, ctype, k);
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
index 8b9888f..b8a0b88 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
@@ -72,8 +72,14 @@ public class ColGroupDDC extends APreAgg {
        @Override
        protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int 
rl, int ru, int offR, int offC,
                double[] values) {
-               if(db.isContiguous() && _colIndexes.length == 1)
-                       decompressToDenseBlockDenseDictSingleColContiguous(db, 
rl, ru, offR, offC, values);
+               if(db.isContiguous() && _colIndexes.length == 1) {
+
+                       if(db.getDim(1) == 1)
+                               
decompressToDenseBlockDenseDictSingleColOutContiguous(db, rl, ru, offR, offC, 
values);
+                       else
+                               
decompressToDenseBlockDenseDictSingleColContiguous(db, rl, ru, offR, offC, 
values);
+
+               }
                else {
                        // generic
                        final int nCol = _colIndexes.length;
@@ -97,6 +103,13 @@ public class ColGroupDDC extends APreAgg {
 
        }
 
+       private void 
decompressToDenseBlockDenseDictSingleColOutContiguous(DenseBlock db, int rl, 
int ru, int offR, int offC,
+               double[] values) {
+               final double[] c = db.values(0);
+               for(int i = rl, offT = rl + offR + _colIndexes[0] + offC; i < 
ru; i++, offT++)
+                       c[offT] += values[_data.getIndex(i)];
+       }
+
        @Override
        protected void decompressToSparseBlockSparseDictionary(SparseBlock ret, 
int rl, int ru, int offR, int offC,
                SparseBlock sb) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
index 7e39b74..f1dcf2e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
@@ -100,37 +100,100 @@ public class ColGroupSDCZeros extends APreAgg {
                        return;
                if(it.value() >= ru)
                        _indexes.cacheIterator(it, ru);
-               else if(ru > _indexes.getOffsetToLast()) {
-                       final int lastOff = _indexes.getOffsetToLast();
-                       final int nCol = _colIndexes.length;
-                       while(true) {
-                               final int idx = offR + it.value();
-                               final double[] c = db.values(idx);
-                               final int off = db.pos(idx) + offC;
-                               final int offDict = 
_data.getIndex(it.getDataIndex()) * nCol;
-                               for(int j = 0; j < nCol; j++)
-                                       c[off + _colIndexes[j]] += 
values[offDict + j];
-                               if(it.value() == lastOff)
-                                       return;
-                               it.next();
+               else if(db.isContiguous() && _colIndexes.length == 1) {
+                       if(ru > _indexes.getOffsetToLast())
+                               
decompressToDenseBlockDenseDictionaryPostSingleColContiguous(db, rl, ru, offR, 
offC, values, it);
+                       else {
+                               if(db.getDim(1) == 1)
+                                       
decompressToDenseBlockDenseDictionaryPreSingleColOutContiguous(db, ru, offR, 
offC, values, it, _data);
+                               else
+                                       
decompressToDenseBlockDenseDictionaryPreSingleColContiguous(db, rl, ru, offR, 
offC, values, it);
+                               _indexes.cacheIterator(it, ru);
                        }
                }
+               else if(ru > _indexes.getOffsetToLast())
+                       decompressToDenseBlockDenseDictionaryPostGeneric(db, 
rl, ru, offR, offC, values, it);
                else {
+                       decompressToDenseBlockDenseDictionaryPreGeneric(db, rl, 
ru, offR, offC, values, it);
+                       _indexes.cacheIterator(it, ru);
+               }
+       }
 
-                       final int nCol = _colIndexes.length;
-                       while(it.isNotOver(ru)) {
-                               final int idx = offR + it.value();
-                               final double[] c = db.values(idx);
-                               final int off = db.pos(idx) + offC;
-                               final int offDict = 
_data.getIndex(it.getDataIndex()) * nCol;
-                               for(int j = 0; j < nCol; j++)
-                                       c[off + _colIndexes[j]] += 
values[offDict + j];
+       private void 
decompressToDenseBlockDenseDictionaryPostSingleColContiguous(DenseBlock db, int 
rl, int ru, int offR,
+               int offC, double[] values, AIterator it) {
+               final int lastOff = _indexes.getOffsetToLast() + offR;
+               final int nCol = db.getDim(1);
+               final double[] c = db.values(0);
+               it.setOff(it.value() + offR);
+               offC += _colIndexes[0];
+               while(it.value() < lastOff) {
+                       final int off = it.value() * nCol + offC;
+                       c[off] += values[_data.getIndex(it.getDataIndex())];
+                       it.next();
+               }
+               final int off = it.value() * nCol + offC;
+               c[off] += values[_data.getIndex(it.getDataIndex())];
+               it.setOff(it.value() - offR);
+       }
 
-                               it.next();
-                       }
-                       _indexes.cacheIterator(it, ru);
+       private void 
decompressToDenseBlockDenseDictionaryPostGeneric(DenseBlock db, int rl, int ru, 
int offR, int offC,
+               double[] values, AIterator it) {
+               final int lastOff = _indexes.getOffsetToLast();
+               final int nCol = _colIndexes.length;
+               while(true) {
+                       final int idx = offR + it.value();
+                       final double[] c = db.values(idx);
+                       final int off = db.pos(idx) + offC;
+                       final int offDict = _data.getIndex(it.getDataIndex()) * 
nCol;
+                       for(int j = 0; j < nCol; j++)
+                               c[off + _colIndexes[j]] += values[offDict + j];
+                       if(it.value() == lastOff)
+                               return;
+                       it.next();
                }
+       }
 
+       private static void 
decompressToDenseBlockDenseDictionaryPreSingleColOutContiguous(DenseBlock db, 
int ru, int offR,
+               int offC, double[] values, AIterator it, AMapToData m) {
+               final double[] c = db.values(0);
+               final int of = offR + offC;
+               final int last = ru + of;
+               it.setOff(it.value() + of);
+               while(it.isNotOver(last)) {
+                       c[it.value()] += values[m.getIndex(it.getDataIndex())];
+                       it.next();
+               }
+               it.setOff(it.value() - of);
+       }
+
+       private void 
decompressToDenseBlockDenseDictionaryPreSingleColContiguous(DenseBlock db, int 
rl, int ru, int offR,
+               int offC, double[] values, AIterator it) {
+               final int last = ru + offR;
+               final int nCol = db.getDim(1);
+               final double[] c = db.values(0);
+               it.setOff(it.value() + offR);
+               offC += _colIndexes[0];
+               while(it.isNotOver(last)) {
+                       final int off = it.value() * nCol + offC;
+                       c[off] += values[_data.getIndex(it.getDataIndex())];
+                       it.next();
+               }
+               it.setOff(it.value() - offR);
+       }
+
+       private void decompressToDenseBlockDenseDictionaryPreGeneric(DenseBlock 
db, int rl, int ru, int offR, int offC,
+               double[] values, AIterator it) {
+               final int nCol = _colIndexes.length;
+               while(it.isNotOver(ru)) {
+                       final int idx = offR + it.value();
+                       final double[] c = db.values(idx);
+                       final int off = db.pos(idx) + offC;
+                       final int offDict = _data.getIndex(it.getDataIndex()) * 
nCol;
+                       for(int j = 0; j < nCol; j++)
+                               c[off + _colIndexes[j]] += values[offDict + j];
+
+                       it.next();
+               }
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java
index bb4d13e..e747e9e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java
@@ -25,22 +25,16 @@ import org.apache.commons.logging.LogFactory;
  * Iterator interface, that returns a iterator of the indexes (not offsets)
  */
 public abstract class AIterator {
-       protected static final Log LOG = 
LogFactory.getLog(AIterator.class.getName());
+       public static final Log LOG = 
LogFactory.getLog(AIterator.class.getName());
 
-       protected int index;
-       protected int dataIndex;
        protected int offset;
 
        /**
         * Main Constructor
         * 
-        * @param index     The current index that correspond to an actual 
value in the dictionary.
-        * @param dataIndex The current index int the offset.
-        * @param offset    The current index in the uncompressed 
representation.
+        * @param offset The current offset into in the uncompressed 
representation.
         */
-       protected AIterator(int index, int dataIndex, int offset) {
-               this.index = index;
-               this.dataIndex = dataIndex;
+       protected AIterator(int offset) {
                this.offset = offset;
        }
 
@@ -58,8 +52,12 @@ public abstract class AIterator {
                return offset;
        }
 
+       public void setOff(int off){
+               offset = off;
+       }
+
        /**
-        * find out if the current offset is not exceeding the index.
+        * Find out if the current offset is not exceeding the index given.
         * 
         * @param ub The offset to not exceed
         * @return boolean if it is exceeded.
@@ -76,9 +74,7 @@ public abstract class AIterator {
         * 
         * @return The Data Index.
         */
-       public int getDataIndex() {
-               return dataIndex;
-       }
+       public abstract int getDataIndex();
 
        /**
         * Get the current offsets index, that points to the underlying offsets 
list.
@@ -87,9 +83,7 @@ public abstract class AIterator {
         * 
         * @return The Offsets Index.
         */
-       public int getOffsetsIndex() {
-               return index;
-       }
+       public abstract int getOffsetsIndex();
 
        /**
         * Skip values until index is achieved.
@@ -111,6 +105,6 @@ public abstract class AIterator {
         * @return The result
         */
        public boolean equals(AIterator o) {
-               return o.index == this.index;
+               return o.getOffsetsIndex() == getOffsetsIndex();
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
index 328dcb5..3e8ecf1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
@@ -76,24 +76,16 @@ public abstract class AOffset implements Serializable {
                else if(row > getOffsetToLast())
                        return null;
 
-               // try the cache first.
+               // Try the cache first.
                OffsetCache c = cacheRow.get();
-               if(c == null) {
-                       if(memorizer != null && memorizer.containsKey(row))
-                               return memorizer.get(row).clone();
-                       AIterator it = getIterator();
-                       it.skipTo(row);
-                       cacheIterator(it.clone(), row);
-                       memorizeIterator(it.clone(), row);
-                       return it;
-               }
-               else if(c.row == row)
+
+               if(c != null && c.row == row)
                        return c.it.clone();
                else {
                        if(memorizer != null && memorizer.containsKey(row))
                                return memorizer.get(row).clone();
                        // Use the cached iterator if it is closer to the 
queried row.
-                       AIterator it = c.row < row ? c.it.clone() : 
getIterator();
+                       AIterator it = c != null && c.row < row ? c.it.clone() 
: getIterator();
                        it.skipTo(row);
                        // cache this new iterator.
                        cacheIterator(it.clone(), row);
@@ -264,7 +256,7 @@ public abstract class AOffset implements Serializable {
                final double[] vals = db.values(rl);
                final int nCol = db.getCumODims(0);
                while(it.offset < cu) {
-                       final int dataOffset = data.get(it.dataIndex) ? 1 : 0;
+                       final int dataOffset = data.get(it.getDataIndex()) ? 1 
: 0;
                        final int start = it.offset + nCol * rl;
                        final int end = it.offset + nCol * ru;
                        for(int offOut = dataOffset, off = start; off < end; 
offOut += nVal, off += nCol)
@@ -280,14 +272,14 @@ public abstract class AOffset implements Serializable {
                final double[] vals = db.values(rl);
                final int nCol = db.getCumODims(0);
                final int last = getOffsetToLast();
-               int dataOffset = data.get(it.dataIndex) ? 1 : 0;
+               int dataOffset = data.get(it.getDataIndex()) ? 1 : 0;
                int start = it.offset + nCol * rl;
                int end = it.offset + nCol * ru;
                for(int offOut = dataOffset, off = start; off < end; offOut += 
nVal, off += nCol)
                        preAV[offOut] += vals[off];
                while(it.offset < last) {
                        it.next();
-                       dataOffset = data.get(it.dataIndex) ? 1 : 0;
+                       dataOffset = data.get(it.getDataIndex()) ? 1 : 0;
                        start = it.offset + nCol * rl;
                        end = it.offset + nCol * ru;
                        for(int offOut = dataOffset, off = start; off < end; 
offOut += nVal, off += nCol)
@@ -330,8 +322,8 @@ public abstract class AOffset implements Serializable {
                int j = apos;
                while(j < alen) {
                        if(aix[j] == it.offset) {
-                               preAV[data[it.dataIndex] & 0xFF] += avals[j++];
-                               if(it.dataIndex >= maxId)
+                               preAV[data[it.getDataIndex()] & 0xFF] += 
avals[j++];
+                               if(it.getDataIndex() >= maxId)
                                        break;
                                it.next();
                        }
@@ -339,7 +331,7 @@ public abstract class AOffset implements Serializable {
                                j++;
                        }
                        else {
-                               if(it.dataIndex >= maxId)
+                               if(it.getDataIndex() >= maxId)
                                        break;
                                it.next();
                        }
@@ -356,8 +348,8 @@ public abstract class AOffset implements Serializable {
                int j = apos;
                while(j < alen) {
                        if(aix[j] == it.offset) {
-                               preAV[data[it.dataIndex]] += avals[j++];
-                               if(it.dataIndex >= maxId)
+                               preAV[data[it.getDataIndex()]] += avals[j++];
+                               if(it.getDataIndex() >= maxId)
                                        break;
                                it.next();
                        }
@@ -365,7 +357,7 @@ public abstract class AOffset implements Serializable {
                                j++;
                        }
                        else {
-                               if(it.dataIndex >= maxId)
+                               if(it.getDataIndex() >= maxId)
                                        break;
                                it.next();
                        }
@@ -382,7 +374,7 @@ public abstract class AOffset implements Serializable {
                int j = apos;
                while(it.offset < last && j < alen) {
                        if(aix[j] == it.offset) {
-                               preAV[data.get(it.dataIndex) ? 1 : 0] += 
avals[j++];
+                               preAV[data.get(it.getDataIndex()) ? 1 : 0] += 
avals[j++];
                                it.next();
                        }
                        if(j < alen)
@@ -394,7 +386,7 @@ public abstract class AOffset implements Serializable {
                while(j < alen && aix[j] < it.offset)
                        j++;
                if(j != alen && aix[j] == it.offset)
-                       preAV[data.get(it.dataIndex) ? 1 : 0] += avals[j];
+                       preAV[data.get(it.getDataIndex()) ? 1 : 0] += avals[j];
 
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
index 4654cdb..7fb7b2b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetByte.java
@@ -108,7 +108,12 @@ public class OffsetByte extends AOffset {
 
        @Override
        public IterateByteOffset getIterator() {
-               return new IterateByteOffset();
+               if(noOverHalf)
+                       return new IterateByteOffsetNoOverHalf();
+               else if(noZero)
+                       return new IterateByteOffsetNoZero();
+               else
+                       return new IterateByteOffset();
        }
 
        @Override
@@ -208,8 +213,8 @@ public class OffsetByte extends AOffset {
                final int maxId = data.length - 1;
 
                int offset = it.offset + off;
-               int index = it.index;
-               int dataIndex = it.dataIndex;
+               int index = it.getOffsetsIndex();
+               int dataIndex = it.getDataIndex();
 
                preAV[data[dataIndex] & 0xFF] += mV[offset];
                while(dataIndex < maxId) {
@@ -230,7 +235,7 @@ public class OffsetByte extends AOffset {
                IterateByteOffset it) {
 
                int offset = it.offset + off;
-               int index = it.index;
+               int index = it.getOffsetsIndex();
 
                while(index < offsets.length) {
                        preAV[data[index] & 0xFF] += mV[offset];
@@ -246,7 +251,7 @@ public class OffsetByte extends AOffset {
                cu += off;
                it.offset += off;
                while(it.offset < cu) {
-                       preAV[data[it.dataIndex] & 0xFF] += mV[it.offset];
+                       preAV[data[it.getDataIndex()] & 0xFF] += mV[it.offset];
                        byte v = offsets[it.index];
                        while(v == 0) {
                                it.offset += maxV;
@@ -266,7 +271,7 @@ public class OffsetByte extends AOffset {
                cu += off;
                it.offset += off;
                while(it.offset < cu) {
-                       preAV[data[it.dataIndex]] += mV[it.offset];
+                       preAV[data[it.getDataIndex()]] += mV[it.offset];
                        byte v = offsets[it.index];
                        while(v == 0) {
                                it.offset += maxV;
@@ -286,7 +291,7 @@ public class OffsetByte extends AOffset {
                cu += off;
                it.offset += off;
                while(it.offset < cu) {
-                       preAV[data[it.dataIndex] & 0xFF] += mV[it.offset];
+                       preAV[data[it.getDataIndex()] & 0xFF] += mV[it.offset];
                        byte v = offsets[it.index];
                        while(v == 0) {
                                it.offset += maxV;
@@ -304,7 +309,7 @@ public class OffsetByte extends AOffset {
                IterateByteOffset it) {
 
                int offset = it.offset + off;
-               int index = it.index;
+               int index = it.getOffsetsIndex();
 
                cu += off;
 
@@ -321,7 +326,7 @@ public class OffsetByte extends AOffset {
        private final void 
preAggregateDenseByteMapRowBelowEndAndNoZeroNoOverHalf(double[] mV, int off, 
double[] preAV,
                int cu, byte[] data, IterateByteOffset it) {
                int offset = it.offset + off;
-               int index = it.index;
+               int index = it.getOffsetsIndex();
 
                cu += off;
 
@@ -338,7 +343,7 @@ public class OffsetByte extends AOffset {
        private final void 
preAggregateDenseByteMapRowBelowEndAndNoZeroNoOverHalfAlsoData(double[] mV, int 
off,
                double[] preAV, int cu, byte[] data, IterateByteOffset it) {
                int offset = it.offset + off;
-               int index = it.index;
+               int index = it.getOffsetsIndex();
 
                cu += off;
 
@@ -375,8 +380,8 @@ public class OffsetByte extends AOffset {
        private void preAggregateDenseCharMapRow(double[] mV, int off, double[] 
preAV, char[] data, IterateByteOffset it) {
                final int maxId = data.length - 1;
                int offset = it.offset + off;
-               int index = it.index;
-               int dataIndex = it.dataIndex;
+               int index = it.getOffsetsIndex();
+               int dataIndex = it.getDataIndex();
 
                preAV[data[dataIndex]] += mV[offset];
                while(dataIndex < maxId) {
@@ -397,7 +402,7 @@ public class OffsetByte extends AOffset {
                IterateByteOffset it) {
 
                int offset = it.offset + off;
-               int index = it.index;
+               int index = it.getOffsetsIndex();
                while(index < offsets.length) {
                        preAV[data[index]] += mV[offset];
                        offset += offsets[index++] & 0xFF;
@@ -411,7 +416,7 @@ public class OffsetByte extends AOffset {
                cu += off;
                it.offset += off;
                while(it.offset < cu) {
-                       preAV[data[it.dataIndex]] += mV[it.offset];
+                       preAV[data[it.getDataIndex()]] += mV[it.offset];
                        byte v = offsets[it.index];
                        while(v == 0) {
                                it.offset += maxV;
@@ -428,7 +433,7 @@ public class OffsetByte extends AOffset {
        private void preAggregateDenseCharMapRowBelowEndAndNoZero(double[] mV, 
int off, double[] preAV, int cu, char[] data,
                IterateByteOffset it) {
                int offset = it.offset + off;
-               int index = it.index;
+               int index = it.getOffsetsIndex();
 
                cu += off;
 
@@ -445,7 +450,7 @@ public class OffsetByte extends AOffset {
        private final void 
preAggregateDenseCharMapRowBelowEndAndNoZeroNoOverHalf(double[] mV, int off, 
double[] preAV,
                int cu, char[] data, IterateByteOffset it) {
                int offset = it.offset + off;
-               int index = it.index;
+               int index = it.getOffsetsIndex();
 
                cu += off;
 
@@ -462,9 +467,10 @@ public class OffsetByte extends AOffset {
        @Override
        protected final void preAggregateDenseMapRowBit(double[] mV, int off, 
double[] preAV, int cu, int nVal, BitSet data,
                AIterator it) {
-               int offset = it.offset + off;
-               int index = it.index;
-               int dataIndex = it.dataIndex;
+               IterateByteOffset itb = (IterateByteOffset) it;
+               int offset = itb.offset + off;
+               int index = itb.getOffsetsIndex();
+               int dataIndex = itb.getDataIndex();
 
                if(cu > offsetToLast) {
                        final int last = offsetToLast + off;
@@ -499,9 +505,9 @@ public class OffsetByte extends AOffset {
                        }
 
                }
-               it.offset = offset - off;
-               it.dataIndex = index;
-               it.index = index;
+               itb.offset = offset - off;
+               itb.dataIndex = index;
+               itb.index = index;
                cacheIterator(it, cu);
        }
 
@@ -520,7 +526,7 @@ public class OffsetByte extends AOffset {
                final double[] vals = db.values(rl);
                final int nCol = db.getCumODims(0);
                while(it.offset < cu) {
-                       final int dataOffset = data[it.dataIndex] & 0xFF;
+                       final int dataOffset = data[it.getDataIndex()] & 0xFF;
                        final int start = it.offset + nCol * rl;
                        final int end = it.offset + nCol * ru;
                        for(int offOut = dataOffset, off = start; off < end; 
offOut += nVal, off += nCol)
@@ -535,8 +541,8 @@ public class OffsetByte extends AOffset {
                byte[] data, IterateByteOffset it) {
                final int maxId = data.length - 1;
                final int offsetStart = it.offset;
-               final int indexStart = it.index;
-               final int dataIndexStart = it.dataIndex;
+               final int indexStart = it.getOffsetsIndex();
+               final int dataIndexStart = it.getDataIndex();
                // all the way to the end of offsets.
                for(int r = rl; r < ru; r++) {
                        final int offOut = (r - rl) * nVal;
@@ -545,10 +551,10 @@ public class OffsetByte extends AOffset {
                        it.offset = offsetStart + off;
                        it.index = indexStart;
                        it.dataIndex = dataIndexStart;
-                       preAV[offOut + data[it.dataIndex] & 0xFF] += 
vals[it.offset];
-                       while(it.dataIndex < maxId) {
+                       preAV[offOut + data[it.getDataIndex()] & 0xFF] += 
vals[it.offset];
+                       while(it.getDataIndex() < maxId) {
                                it.next();
-                               preAV[offOut + data[it.dataIndex] & 0xFF] += 
vals[it.offset];
+                               preAV[offOut + data[it.getDataIndex()] & 0xFF] 
+= vals[it.offset];
                        }
                }
        }
@@ -568,7 +574,7 @@ public class OffsetByte extends AOffset {
                int nVal, char[] data, IterateByteOffset it) {
                final double[] vals = db.values(rl);
                while(it.offset < cu) {
-                       final int dataOffset = data[it.dataIndex];
+                       final int dataOffset = data[it.getDataIndex()];
                        for(int r = rl, offOut = dataOffset; r < ru; r++, 
offOut += nVal)
                                preAV[offOut] += vals[it.offset + db.pos(r)];
                        it.next();
@@ -581,8 +587,8 @@ public class OffsetByte extends AOffset {
                final int maxId = data.length - 1;
                // all the way to the end.
                final int offsetStart = it.offset;
-               final int indexStart = it.index;
-               final int dataIndexStart = it.dataIndex;
+               final int indexStart = it.getOffsetsIndex();
+               final int dataIndexStart = it.getDataIndex();
                for(int r = rl; r < ru; r++) {
                        final int offOut = (r - rl) * nVal;
                        final int off = db.pos(r);
@@ -590,22 +596,29 @@ public class OffsetByte extends AOffset {
                        it.offset = offsetStart + off;
                        it.index = indexStart;
                        it.dataIndex = dataIndexStart;
-                       preAV[offOut + data[it.dataIndex]] += vals[it.offset];
-                       while(it.dataIndex < maxId) {
+                       preAV[offOut + data[it.getDataIndex()]] += 
vals[it.offset];
+                       while(it.getDataIndex() < maxId) {
                                it.next();
-                               preAV[offOut + data[it.dataIndex]] += 
vals[it.offset];
+                               preAV[offOut + data[it.getDataIndex()]] += 
vals[it.offset];
                        }
                }
        }
 
        private class IterateByteOffset extends AIterator {
 
+               protected int index;
+               protected int dataIndex;
+
                private IterateByteOffset() {
-                       super(0, 0, offsetToFirst);
+                       super(offsetToFirst);
+                       index = 0;
+                       dataIndex = 0;
                }
 
                private IterateByteOffset(int index, int dataIndex, int offset) 
{
-                       super(index, dataIndex, offset);
+                       super(offset);
+                       this.index = index;
+                       this.dataIndex = dataIndex;
                }
 
                @Override
@@ -625,8 +638,7 @@ public class OffsetByte extends AOffset {
                public int skipTo(int idx) {
                        if(noOverHalf) {
                                while(offset < idx && index < offsets.length) {
-                                       byte v = offsets[index];
-                                       offset += v;
+                                       offset += offsets[index];
                                        index++;
                                }
                                dataIndex = index;
@@ -644,6 +656,86 @@ public class OffsetByte extends AOffset {
                public IterateByteOffset clone() {
                        return new IterateByteOffset(index, dataIndex, offset);
                }
+
+               @Override
+               public int getDataIndex() {
+                       return dataIndex;
+               }
+
+               @Override
+               public int getOffsetsIndex() {
+                       return index;
+               }
        }
 
+       private class IterateByteOffsetNoZero extends IterateByteOffset {
+
+               private IterateByteOffsetNoZero() {
+                       super();
+               }
+
+               private IterateByteOffsetNoZero(int index, int dataIndex, int 
offset) {
+                       super(index, dataIndex, offset);
+               }
+
+               @Override
+               public void next() {
+                       byte v = offsets[index];
+                       offset += v & 0xFF;
+                       index++;
+                       dataIndex++;
+               }
+
+               @Override
+               public int skipTo(int idx) {
+                       while(offset < idx && index < offsets.length) {
+                               int v = offsets[index] & 0xFF;
+                               offset += v;
+                               index++;
+                       }
+                       dataIndex = index;
+
+                       return offset;
+               }
+
+               @Override
+               public IterateByteOffsetNoZero clone() {
+                       return new IterateByteOffsetNoZero(index, dataIndex, 
offset);
+               }
+
+       }
+
+       private class IterateByteOffsetNoOverHalf extends IterateByteOffset {
+
+               private IterateByteOffsetNoOverHalf() {
+                       super();
+               }
+
+               private IterateByteOffsetNoOverHalf(int index, int dataIndex, 
int offset) {
+                       super(index, dataIndex, offset);
+               }
+
+               @Override
+               public void next() {
+                       offset += offsets[index];
+                       index++;
+                       dataIndex++;
+               }
+
+               @Override
+               public int skipTo(int idx) {
+                       while(offset < idx && index < offsets.length) {
+                               offset += offsets[index];
+                               index++;
+                       }
+                       dataIndex = index;
+
+                       return offset;
+               }
+
+               @Override
+               public IterateByteOffsetNoOverHalf clone() {
+                       return new IterateByteOffsetNoOverHalf(index, 
dataIndex, offset);
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetChar.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetChar.java
index 695d6c5..f633c53 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetChar.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetChar.java
@@ -179,46 +179,26 @@ public class OffsetChar extends AOffset {
        @Override
        protected final void preAggregateDenseMapRowBit(double[] mV, int off, 
double[] preAV, int cu, int nVal, BitSet data,
                AIterator it) {
-               int offset = it.offset + off;
-               int index = it.index;
-               int dataIndex = it.dataIndex;
+               it.offset += off;
 
                if(cu > offsetToLast) {
                        final int last = offsetToLast + off;
 
-                       while(offset < last) {
-                               preAV[data.get(dataIndex) ? 1 : 0] += 
mV[offset];
-                               char v = offsets[index];
-                               while(v == 0) {
-                                       offset += maxV;
-                                       index++;
-                                       v = offsets[index];
-                               }
-                               offset += v;
-                               index++;
-                               dataIndex++;
+                       while(it.offset < last) {
+                               preAV[data.get(it.getDataIndex()) ? 1 : 0] += 
mV[it.offset];
+                               it.next();
+
                        }
-                       preAV[data.get(dataIndex) ? 1 : 0] += mV[offset];
+                       preAV[data.get(it.getDataIndex()) ? 1 : 0] += 
mV[it.offset];
                }
                else {
                        final int last = cu + off;
-                       while(offset < last) {
-                               preAV[data.get(dataIndex) ? 1 : 0] += 
mV[offset];
-                               char v = offsets[index];
-                               while(v == 0) {
-                                       offset += maxV;
-                                       index++;
-                                       v = offsets[index];
-                               }
-                               offset += v;
-                               index++;
-                               dataIndex++;
+                       while(it.offset < last) {
+                               preAV[data.get(it.getDataIndex()) ? 1 : 0] += 
mV[it.offset];
+                               it.next();
                        }
-
                }
-               it.offset = offset - off;
-               it.dataIndex = index;
-               it.index = index;
+               it.offset -= off;
                cacheIterator(it, cu);
        }
 
@@ -226,9 +206,7 @@ public class OffsetChar extends AOffset {
        protected void preAggregateDenseMapRowsByte(DenseBlock db, double[] 
preAV, int rl, int ru, int cl, int cu, int nVal,
                byte[] data, AIterator it) {
 
-               final int offsetStart = it.offset;
-               final int indexStart = it.index;
-               final int dataIndexStart = it.dataIndex;
+               final AIterator sIt = it.clone();
                if(cu < getOffsetToLast() + 1) {
                        // inside offsets
                        for(int r = rl; r < ru; r++) {
@@ -236,11 +214,10 @@ public class OffsetChar extends AOffset {
                                final double[] vals = db.values(r);
                                final int off = db.pos(r);
                                final int cur = cu + off;
-                               it.offset = offsetStart + off;
-                               it.index = indexStart;
-                               it.dataIndex = dataIndexStart;
+                               it = sIt.clone();
+                               it.offset += off;
                                while(it.offset < cur) {
-                                       preAV[offOut + data[it.dataIndex] & 
0xFF] += vals[it.offset];
+                                       preAV[offOut + data[it.getDataIndex()] 
& 0xFF] += vals[it.offset];
                                        it.next();
                                }
                                it.offset -= off;
@@ -254,13 +231,12 @@ public class OffsetChar extends AOffset {
                                final int offOut = (r - rl) * nVal;
                                final int off = db.pos(r);
                                final double[] vals = db.values(r);
-                               it.offset = offsetStart + off;
-                               it.index = indexStart;
-                               it.dataIndex = dataIndexStart;
-                               preAV[offOut + data[it.dataIndex] & 0xFF] += 
vals[it.offset];
-                               while(it.dataIndex < maxId) {
+                               it = sIt.clone();
+                               it.offset = it.offset + off;
+                               preAV[offOut + data[it.getDataIndex()] & 0xFF] 
+= vals[it.offset];
+                               while(it.getDataIndex() < maxId) {
                                        it.next();
-                                       preAV[offOut + data[it.dataIndex] & 
0xFF] += vals[it.offset];
+                                       preAV[offOut + data[it.getDataIndex()] 
& 0xFF] += vals[it.offset];
                                }
                        }
                }
@@ -269,10 +245,10 @@ public class OffsetChar extends AOffset {
        @Override
        protected void preAggregateDenseMapRowsChar(DenseBlock db, double[] 
preAV, int rl, int ru, int cl, int cu, int nVal,
                char[] data, AIterator it) {
-
-               final int offsetStart = it.offset;
-               final int indexStart = it.index;
-               final int dataIndexStart = it.dataIndex;
+               final IterateCharOffset itb = (IterateCharOffset) it;
+               final int offsetStart = itb.offset;
+               final int indexStart = itb.index;
+               final int dataIndexStart = itb.dataIndex;
                if(cu < getOffsetToLast() + 1) {
 
                        for(int r = rl; r < ru; r++) {
@@ -280,17 +256,17 @@ public class OffsetChar extends AOffset {
                                final double[] vals = db.values(r);
                                final int off = db.pos(r);
                                final int cur = cu + off;
-                               it.offset = offsetStart + off;
-                               it.index = indexStart;
-                               it.dataIndex = dataIndexStart;
-                               while(it.offset < cur) {
-                                       preAV[offOut + data[it.dataIndex]] += 
vals[it.offset];
-                                       it.next();
+                               itb.offset = offsetStart + off;
+                               itb.index = indexStart;
+                               itb.dataIndex = dataIndexStart;
+                               while(itb.offset < cur) {
+                                       preAV[offOut + data[itb.dataIndex]] += 
vals[itb.offset];
+                                       itb.next();
                                }
-                               it.offset -= off;
+                               itb.offset -= off;
                        }
 
-                       cacheIterator(it, cu);
+                       cacheIterator(itb, cu);
                }
                else {
                        final int maxId = data.length - 1;
@@ -299,13 +275,13 @@ public class OffsetChar extends AOffset {
                                final int offOut = (r - rl) * nVal;
                                final int off = db.pos(r);
                                final double[] vals = db.values(r);
-                               it.offset = offsetStart + off;
-                               it.index = indexStart;
-                               it.dataIndex = dataIndexStart;
-                               preAV[offOut + data[it.dataIndex]] += 
vals[it.offset];
-                               while(it.dataIndex < maxId) {
-                                       it.next();
-                                       preAV[offOut + data[it.dataIndex]] += 
vals[it.offset];
+                               itb.offset = offsetStart + off;
+                               itb.index = indexStart;
+                               itb.dataIndex = dataIndexStart;
+                               preAV[offOut + data[itb.dataIndex]] += 
vals[itb.offset];
+                               while(itb.dataIndex < maxId) {
+                                       itb.next();
+                                       preAV[offOut + data[itb.dataIndex]] += 
vals[itb.offset];
                                }
                        }
                }
@@ -313,12 +289,20 @@ public class OffsetChar extends AOffset {
 
        private class IterateCharOffset extends AIterator {
 
+               protected int index;
+               protected int dataIndex;
+
                private IterateCharOffset() {
-                       super(0, 0, offsetToFirst);
+                       super(offsetToFirst);
+                       index = 0;
+                       dataIndex = 0;
                }
 
                private IterateCharOffset(int index, int dataIndex, int offset) 
{
-                       super(index, dataIndex, offset);
+                       super(offset);
+                       this.index = index;
+                       this.dataIndex = dataIndex;
+
                }
 
                @Override
@@ -350,5 +334,15 @@ public class OffsetChar extends AOffset {
                public IterateCharOffset clone() {
                        return new IterateCharOffset(index, dataIndex, offset);
                }
+
+               @Override
+               public int getDataIndex() {
+                       return dataIndex;
+               }
+
+               @Override
+               public int getOffsetsIndex() {
+                       return index;
+               }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
index df1c285..a1ac8ff 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
@@ -147,8 +147,7 @@ public class CLALibDecompress {
                                ret.allocateDenseBlock();
                }
 
-               // final int blklen = Math.max(nRows / (k * 2), 512);
-               final int blklen = Math.max(nRows / k , 512);
+               final int blklen = Math.max(nRows / k, 512);
 
                // check if we are using filtered groups, and if we are not 
force constV to null
                if(groups == filteredGroups)
@@ -171,7 +170,7 @@ public class CLALibDecompress {
                        ret.setNonZeros(nonZeros);
                }
                else
-                       decompressDenseMultiThread(ret, filteredGroups, nRows, 
blklen, constV, eps, overlapping, k);
+                       decompressDenseMultiThread(ret, filteredGroups, nRows, 
blklen, constV, eps, k);
 
                ret.examSparsity();
                return ret;
@@ -232,14 +231,27 @@ public class CLALibDecompress {
                }
        }
 
+       protected static void decompressDenseMultiThread(MatrixBlock ret, 
List<AColGroup> groups, double[] constV, int k) {
+               final int nRows = ret.getNumRows();
+               final double eps = getEps(constV);
+               final int blklen = Math.max(nRows / k, 512);
+               decompressDenseMultiThread(ret, groups, nRows, blklen, constV, 
eps, k);
+       }
+
+       protected static void decompressDenseMultiThread(MatrixBlock ret, 
List<AColGroup> groups, double[] constV,
+               double eps, int k) {
+               final int nRows = ret.getNumRows();
+               final int blklen = Math.max(nRows / k, 512);
+               decompressDenseMultiThread(ret, groups, nRows, blklen, constV, 
eps, k);
+       }
+
        private static void decompressDenseMultiThread(MatrixBlock ret, 
List<AColGroup> filteredGroups, int rlen, int blklen,
-               double[] constV, double eps, boolean overlapping, int k) {
+               double[] constV, double eps, int k) {
                try {
                        final ExecutorService pool = CommonThreadPool.get(k);
                        final ArrayList<DecompressDenseTask> tasks = new 
ArrayList<>();
                        for(int i = 0; i < rlen; i += blklen)
-                               tasks.add(
-                                       new DecompressDenseTask(filteredGroups, 
ret, eps, i, Math.min(i + blklen, rlen), overlapping, constV));
+                               tasks.add(new 
DecompressDenseTask(filteredGroups, ret, eps, i, Math.min(i + blklen, rlen), 
constV));
 
                        long nnz = 0;
                        for(Future<Long> rt : pool.invokeAll(tasks))
@@ -299,16 +311,14 @@ public class CLALibDecompress {
                private final int _rl;
                private final int _ru;
                private final double[] _constV;
-               private final boolean _overlapping;
 
                protected DecompressDenseTask(List<AColGroup> colGroups, 
MatrixBlock ret, double eps, int rl, int ru,
-                       boolean overlapping, double[] constV) {
+                       double[] constV) {
                        _colGroups = colGroups;
                        _ret = ret;
                        _eps = eps;
                        _rl = rl;
                        _ru = ru;
-                       _overlapping = overlapping;
                        _constV = constV;
                }
 
@@ -316,14 +326,14 @@ public class CLALibDecompress {
                public Long call() {
                        final int blk = 1024;
                        long nnz = 0;
-                       for(int b = _rl; b < _ru; b+= blk){
-                               int e = Math.min(b + blk , _ru);
+                       for(int b = _rl; b < _ru; b += blk) {
+                               int e = Math.min(b + blk, _ru);
                                for(AColGroup grp : _colGroups)
                                        
grp.decompressToDenseBlock(_ret.getDenseBlock(), b, e);
 
                                if(_constV != null)
                                        addVector(_ret, _constV, _eps, b, e);
-                               nnz += _overlapping ? 0 : 
_ret.recomputeNonZeros(b, e - 1);
+                               nnz += _ret.recomputeNonZeros(b, e - 1);
                        }
 
                        return nnz;
@@ -369,24 +379,83 @@ public class CLALibDecompress {
                final int ru) {
                final int nCols = ret.getNumColumns();
                final DenseBlock db = ret.getDenseBlock();
-               if(eps == 0) {
-                       for(int row = rl; row < ru; row++) {
-                               final double[] _retV = db.values(row);
-                               final int off = db.pos(row);
-                               for(int col = 0; col < nCols; col++)
-                                       _retV[off + col] += rowV[col];
+
+               if(nCols == 1) {
+                       if(eps == 0)
+                               addValue(db.values(0), rowV[0], rl, ru);
+                       else
+                               addValueEps(db.values(0), rowV[0], eps, rl, ru);
+               }
+               else if(db.isContiguous()) {
+                       if(eps == 0)
+                               addVectorContiguousNoEps(db.values(0), rowV, 
nCols, rl, ru);
+                       else
+                               addVectorContiguousEps(db.values(0), rowV, 
nCols, eps, rl, ru);
+               }
+               else if(eps == 0)
+                       addVectorNoEps(db, rowV, nCols, rl, ru);
+               else
+                       addVectorEps(db, rowV, nCols, eps, rl, ru);
+
+       }
+
+       private static void addValue(final double[] retV, final double v, final 
int rl, final int ru) {
+               for(int off = rl; off < ru; off++)
+                       retV[off] += v;
+       }
+
+       private static void addValueEps(final double[] retV, final double v, 
final double eps, final int rl, final int ru) {
+               for(int off = rl; off < ru; off++) {
+                       final double e = retV[off] + v;
+                       if(Math.abs(e) <= eps)
+                               retV[off] = 0;
+                       else
+                               retV[off] = e;
+               }
+       }
+
+       private static void addVectorContiguousNoEps(final double[] retV, final 
double[] rowV, final int nCols, final int rl,
+               final int ru) {
+               for(int off = rl * nCols; off < ru * nCols; off += nCols) {
+                       for(int col = 0; col < nCols; col++) {
+                               final int out = off + col;
+                               retV[out] += rowV[col];
                        }
                }
-               else {
-                       for(int row = rl; row < ru; row++) {
-                               final double[] _retV = db.values(row);
-                               final int off = db.pos(row);
-                               for(int col = 0; col < nCols; col++) {
-                                       final int out = off + col;
-                                       _retV[out] += rowV[col];
-                                       if(Math.abs(_retV[out]) <= eps)
-                                               _retV[out] = 0;
-                               }
+       }
+
+       private static void addVectorContiguousEps(final double[] retV, final 
double[] rowV, final int nCols,
+               final double eps, final int rl, final int ru) {
+               for(int off = rl * nCols; off < ru * nCols; off += nCols) {
+                       for(int col = 0; col < nCols; col++) {
+                               final int out = off + col;
+                               retV[out] += rowV[col];
+                               if(Math.abs(retV[out]) <= eps)
+                                       retV[out] = 0;
+                       }
+               }
+       }
+
+       private static void addVectorNoEps(final DenseBlock db, final double[] 
rowV, final int nCols, final int rl,
+               final int ru) {
+               for(int row = rl; row < ru; row++) {
+                       final double[] _retV = db.values(row);
+                       final int off = db.pos(row);
+                       for(int col = 0; col < nCols; col++)
+                               _retV[off + col] += rowV[col];
+               }
+       }
+
+       private static void addVectorEps(final DenseBlock db, final double[] 
rowV, final int nCols, final double eps,
+               final int rl, final int ru) {
+               for(int row = rl; row < ru; row++) {
+                       final double[] _retV = db.values(row);
+                       final int off = db.pos(row);
+                       for(int col = 0; col < nCols; col++) {
+                               final int out = off + col;
+                               _retV[out] += rowV[col];
+                               if(Math.abs(_retV[out]) <= eps)
+                                       _retV[out] = 0;
                        }
                }
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index d46e071..9d6b889 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -64,7 +64,8 @@ public class CLALibLeftMultBy {
                int k) {
                if(left.isEmpty() || right.isEmpty())
                        return prepareEmptyReturnMatrix(right, left, ret, true);
-               LOG.warn("Transposing matrix block for transposed left matrix 
multiplication");
+               if(left.getNumColumns() > 1)
+                       LOG.warn("Transposing matrix block for transposed left 
matrix multiplication");
                MatrixBlock transposed = new MatrixBlock(left.getNumColumns(), 
left.getNumRows(), false);
                LibMatrixReorg.transpose(left, transposed, k);
                ret = leftMultByMatrix(right, transposed, ret, k);
@@ -276,12 +277,11 @@ public class CLALibLeftMultBy {
 
        private static void LMMParallel(List<AColGroup> filteredGroups, 
MatrixBlock that, MatrixBlock ret, double[] rowSums,
                boolean overlapping, int k) {
-               LOG.debug("Parallel left matrix multiplication thatRows: " + 
that.getNumRows());
                try {
                        final ExecutorService pool = CommonThreadPool.get(k);
                        final ArrayList<Callable<MatrixBlock>> tasks = new 
ArrayList<>();
                        final int rl = that.getNumRows();
-                       final int numberSplits = 
Math.max((filteredGroups.size() / k), 1);
+                       final int numberSplits = 
Math.min(filteredGroups.size(), k);
                        final int rowBlockThreads = Math.max(k / numberSplits, 
1);
                        final int rowBlockSize = rl <= rowBlockThreads ? 1 : 
Math.min(Math.max(rl / rowBlockThreads, 1), 16);
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
new file mode 100644
index 0000000..5dfe25f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import java.util.List;
+
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.lops.MapMultChain.ChainType;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+
+public class CLALibMMChain {
+
+       public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock 
v, MatrixBlock w, MatrixBlock out,
+               ChainType ctype, int k) {
+
+               if(x.isEmpty())
+                       return returnEmpty(x, out);
+
+               // Morph the columns to efficient types for the operation.
+               x = filterColGroups(x);
+
+               // Allow a compressed intermediate only for a single column group, where it is 
guaranteed not to be overlapping.
+               final boolean allowOverlap = x.getColGroups().size() == 1 && 
isOverlappingAllowed();
+
+               // Right hand side multiplication
+               MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(x, v, 
null, k, allowOverlap);
+
+               if(ctype == ChainType.XtwXv) // Multiply intermediate with 
vector if needed
+                       tmp = binaryMultW(tmp, w, k);
+
+               if(tmp instanceof CompressedMatrixBlock)
+                       // Compressed Compressed Matrix Multiplication
+                       CLALibLeftMultBy.leftMultByMatrixTransposed(x, 
(CompressedMatrixBlock) tmp, out, k);
+               else
+                       // LMM with Compressed - uncompressed multiplication.
+                       CLALibLeftMultBy.leftMultByMatrixTransposed(x, tmp, 
out, k);
+
+               if(out.getNumColumns() != 1) // transpose the output to make it 
a single-column output if needed
+                       out = LibMatrixReorg.transposeInPlace(out, k);
+
+               return out;
+       }
+
+       private static boolean isOverlappingAllowed() {
+               return 
ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.COMPRESSED_OVERLAPPING);
+       }
+
+       private static MatrixBlock returnEmpty(CompressedMatrixBlock x, 
MatrixBlock out) {
+               out = prepareReturn(x, out);
+               return out;
+       }
+
+       private static MatrixBlock prepareReturn(CompressedMatrixBlock x, 
MatrixBlock out) {
+               final int clen = x.getNumColumns();
+               if(out != null)
+                       out.reset(clen, 1, false);
+               else
+                       out = new MatrixBlock(clen, 1, false);
+               return out;
+       }
+
+       private static MatrixBlock binaryMultW(MatrixBlock tmp, MatrixBlock w, 
int k) {
+               final BinaryOperator bop = new 
BinaryOperator(Multiply.getMultiplyFnObject(), k);
+               if(tmp instanceof CompressedMatrixBlock)
+                       tmp = CLALibBinaryCellOp.binaryOperationsRight(bop, 
(CompressedMatrixBlock) tmp, w, null);
+               else
+                       LibMatrixBincell.bincellOpInPlace(tmp, w, bop);
+               return tmp;
+       }
+
+       private static CompressedMatrixBlock 
filterColGroups(CompressedMatrixBlock x) {
+               final List<AColGroup> groups = x.getColGroups();
+               final boolean shouldFilter = 
CLALibUtils.shouldPreFilter(groups);
+               if(shouldFilter) {
+                       final int nCol = x.getNumColumns();
+                       final double[] constV = new double[nCol];
+                       final List<AColGroup> filteredGroups = 
CLALibUtils.filterGroups(groups, constV);
+
+                       AColGroup c = ColGroupFactory.genColGroupConst(constV);
+                       filteredGroups.add(c);
+                       x.allocateColGroupList(filteredGroups);
+                       return x;
+               }
+               else
+                       return x;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
index 3ebdd3a..be43e58 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
@@ -30,15 +30,19 @@ import java.util.concurrent.Future;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.apache.sysds.utils.DMLCompressionStatistics;
 
 public class CLALibRightMultBy {
        private static final Log LOG = 
LogFactory.getLog(CLALibRightMultBy.class.getName());
@@ -52,39 +56,45 @@ public class CLALibRightMultBy {
        public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, int k,
                boolean allowOverlap) {
 
-               if(m2.isEmpty()) {
+               final int rr = m1.getNumRows();
+               final int rc = m2.getNumColumns();
+
+               if(m1.isEmpty() || m2.isEmpty()) {
                        LOG.trace("Empty right multiply");
                        if(ret == null)
-                               ret = new MatrixBlock(m1.getNumRows(), 
m2.getNumColumns(), 0);
+                               ret = new MatrixBlock(rr, rc, 0);
                        else
-                               ret.reset(m1.getNumRows(), m2.getNumColumns(), 
0);
+                               ret.reset(rr, rc, 0);
+                       return ret;
                }
                else {
                        if(m2 instanceof CompressedMatrixBlock)
                                m2 = ((CompressedMatrixBlock) 
m2).getUncompressed("Uncompressed right side of right MM");
 
-                       ret = rightMultByMatrixOverlapping(m1, m2, k);
-
-                       if(ret instanceof CompressedMatrixBlock) {
-                               if(!allowOverlap)
-                                       ret = ((CompressedMatrixBlock) 
ret).getUncompressed("Overlapping not allowed");
-                               else {
-                                       final double compressedSize = 
ret.getInMemorySize();
-                                       final double uncompressedSize = 
MatrixBlock.estimateSizeDenseInMemory(ret.getNumRows(),
-                                               ret.getNumColumns());
-                                       if(compressedSize > uncompressedSize)
-                                               ret = ((CompressedMatrixBlock) 
ret).getUncompressed(
-                                                       "Overlapping rep to 
big: " + compressedSize + " vs Uncompressed " + uncompressedSize);
-                               }
+                       if(!allowOverlap) {
+                               LOG.trace("Overlapping output not allowed in 
call to Right MM");
+                               return RMM(m1, m2, k);
                        }
-               }
 
-               ret.recomputeNonZeros();
+                       final CompressedMatrixBlock retC = RMMOverlapping(m1, 
m2, k);
+                       final double cs = retC.getInMemorySize();
+                       final double us = 
MatrixBlock.estimateSizeDenseInMemory(rr, rc);
+                       if(cs > us)
+                               return retC.getUncompressed("Overlapping rep to 
big: " + cs + " vs uncompressed " + us);
+                       else if(retC.isEmpty())
+                               return retC;
+                       else {
+                               if(retC.isOverlapping())
+                                       retC.setNonZeros((long) rr * rc); // 
set non zeros to fully dense in case of overlapping.
+                               else
+                                       retC.recomputeNonZeros(); // recompute 
if non overlapping compressed out.
+                               return retC;
+                       }
+               }
 
-               return ret;
        }
 
-       private static MatrixBlock 
rightMultByMatrixOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k) 
{
+       private static CompressedMatrixBlock 
RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k) {
                final int rl = m1.getNumRows();
                final int cr = that.getNumColumns();
                final int rr = that.getNumRows(); // shared dim
@@ -101,9 +111,9 @@ public class CLALibRightMultBy {
 
                boolean containsNull = false;
                if(k == 1)
-                       containsNull = 
rightMultByMatrixOverlappingSingleThread(filteredGroups, that, retCg);
+                       containsNull = RMMSingle(filteredGroups, that, retCg);
                else
-                       containsNull = 
rightMultByMatrixOverlappingMultiThread(filteredGroups, that, retCg, k);
+                       containsNull = RMMParallel(filteredGroups, that, retCg, 
k);
 
                if(constV != null) {
                        AColGroup cRet = 
ColGroupFactory.genColGroupConst(constV).rightMultByMatrix(that);
@@ -121,8 +131,58 @@ public class CLALibRightMultBy {
                return ret;
        }
 
-       private static boolean 
rightMultByMatrixOverlappingSingleThread(List<AColGroup> filteredGroups, 
MatrixBlock that,
-               List<AColGroup> retCg) {
+       private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock 
that, int k) {
+               // this version returns a decompressed result.
+               final int rl = m1.getNumRows();
+               final int cr = that.getNumColumns();
+               final int rr = that.getNumRows(); // shared dim
+               final List<AColGroup> colGroups = m1.getColGroups();
+               final List<AColGroup> retCg = new ArrayList<>();
+
+               final boolean shouldFilter = 
CLALibUtils.shouldPreFilter(colGroups);
+
+               // start allocation of output.
+               MatrixBlock ret = new MatrixBlock(rl, cr, false);
+               final Future<MatrixBlock> f = ret.allocateBlockAsync();
+
+               double[] constV = shouldFilter ? new double[rr] : null;
+               final List<AColGroup> filteredGroups = 
CLALibUtils.filterGroups(colGroups, constV);
+               if(colGroups == filteredGroups)
+                       constV = null;
+
+               if(k == 1)
+                       RMMSingle(filteredGroups, that, retCg);
+               else
+                       RMMParallel(filteredGroups, that, retCg, k);
+
+               if(constV != null) {
+                       ColGroupConst cRet = (ColGroupConst) 
ColGroupFactory.genColGroupConst(constV).rightMultByMatrix(that);
+                       constV = cRet.getValues(); // overwrite constV
+               }
+
+               final Timing time = new Timing(true);
+
+               ret = asyncRet(f);
+               CLALibDecompress.decompressDenseMultiThread(ret, retCg, constV, 
0, k);
+
+               if(DMLScript.STATISTICS) {
+                       final double t = time.stop();
+                       DMLCompressionStatistics.addDecompressTime(t, k);
+               }
+
+               return ret;
+       }
+
+       private static <T> T asyncRet(Future<T> in) {
+               try {
+                       return in.get();
+               }
+               catch(Exception e) {
+                       throw new DMLRuntimeException(e);
+               }
+       }
+
+       private static boolean RMMSingle(List<AColGroup> filteredGroups, 
MatrixBlock that, List<AColGroup> retCg) {
                boolean containsNull = false;
                for(AColGroup g : filteredGroups) {
                        AColGroup retG = g.rightMultByMatrix(that);
@@ -134,8 +194,7 @@ public class CLALibRightMultBy {
                return containsNull;
        }
 
-       private static boolean 
rightMultByMatrixOverlappingMultiThread(List<AColGroup> filteredGroups, 
MatrixBlock that,
-               List<AColGroup> retCg, int k) {
+       private static boolean RMMParallel(List<AColGroup> filteredGroups, 
MatrixBlock that, List<AColGroup> retCg, int k) {
                ExecutorService pool = CommonThreadPool.get(k);
                boolean containsNull = false;
                try {
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java 
b/src/test/java/org/apache/sysds/test/TestUtils.java
index a0ba5bf..a45ecd1 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -991,7 +991,7 @@ public class TestUtils
                final int ar = actualMatrix.getNumRows();
                final int ac = actualMatrix.getNumColumns();
                if(er != ar || ec != ac)
-                       fail("The number of rows and columns does not match in 
matrices");
+                       fail("The number of rows and columns does not match in 
matrices expected: " + er + " " + ec + " actual: " + ar + " "+ ac);
        }
 
        public static void assertEqualColsAndRows(MatrixBlock expectedMatrix, 
MatrixBlock actualMatrix, String message) {
diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTestPreAggregate.java
 
b/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTestPreAggregate.java
index 0d7f882..a1cde54 100644
--- 
a/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTestPreAggregate.java
+++ 
b/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTestPreAggregate.java
@@ -116,12 +116,24 @@ public abstract class OffsetTestPreAggregate {
 
        @Test
        public void preAggByteMapFirstRow() {
-               preAggMapRow(0);
+               try {
+                       preAggMapRow(0);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
        }
 
        @Test
        public void preAggByteMapSecondRow() {
-               preAggMapRow(1);
+               try {
+                       preAggMapRow(1);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
        }
 
        protected abstract void preAggMapRow(int row);
@@ -295,11 +307,11 @@ public abstract class OffsetTestPreAggregate {
                }
        }
 
-       protected double[] multiRowPreAggRangeSafe(int rl, int ru){
-               try{
+       protected double[] multiRowPreAggRangeSafe(int rl, int ru) {
+               try {
                        return multiRowPreAggRange(rl, ru);
                }
-               catch(Exception e){
+               catch(Exception e) {
                        e.printStackTrace();
                        fail(e.toString());
                        return null;
@@ -321,33 +333,57 @@ public abstract class OffsetTestPreAggregate {
 
        @Test
        public void multiRowPreAggRangeBeforeLast01() {
-               if(data.length > 2) {
-                       double[] agg = multiRowPreAggRangeBeforeLast(1, 3);
-                       compareMultiRowAggBeforeLast(agg, 1, 3);
+               try {
+                       if(data.length > 2) {
+                               double[] agg = multiRowPreAggRangeBeforeLast(1, 
3);
+                               compareMultiRowAggBeforeLast(agg, 1, 3);
+                       }
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
                }
        }
 
        @Test
        public void multiRowPreAggRangeBeforeLast02() {
-               if(data.length > 2) {
-                       double[] agg = multiRowPreAggRangeBeforeLast(2, 4);
-                       compareMultiRowAggBeforeLast(agg, 2, 4);
+               try {
+                       if(data.length > 2) {
+                               double[] agg = multiRowPreAggRangeBeforeLast(2, 
4);
+                               compareMultiRowAggBeforeLast(agg, 2, 4);
+                       }
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
                }
        }
 
        @Test
        public void multiRowPreAggRangeBeforeLast03() {
-               if(data.length > 2) {
-                       double[] agg = multiRowPreAggRangeBeforeLast(0, 4);
-                       compareMultiRowAggBeforeLast(agg, 0, 4);
+               try {
+                       if(data.length > 2) {
+                               double[] agg = multiRowPreAggRangeBeforeLast(0, 
4);
+                               compareMultiRowAggBeforeLast(agg, 0, 4);
+                       }
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
                }
        }
 
        @Test
        public void multiRowPreAggRangeBeforeLast04() {
-               if(data.length > 2) {
-                       double[] agg = multiRowPreAggRangeBeforeLast(0, 3);
-                       compareMultiRowAggBeforeLast(agg, 0, 3);
+               try {
+                       if(data.length > 2) {
+                               double[] agg = multiRowPreAggRangeBeforeLast(0, 
3);
+                               compareMultiRowAggBeforeLast(agg, 0, 3);
+                       }
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
                }
        }
 
@@ -360,7 +396,7 @@ public abstract class OffsetTestPreAggregate {
                        if(agg[of * 2 + 1] != v)
                                fail("\naggregate to wrong index");
                        if(!Precision.equals(agg[of * 2], s[r] - v - v2, eps))
-                               fail("\naggregate result is not sum minus 
value:" + agg[of * 2] + " vs " + (s[r] - v- v2));
+                               fail("\naggregate result is not sum minus 
value:" + agg[of * 2] + " vs " + (s[r] - v - v2));
                }
        }
 

Reply via email to