This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 5303344  [SYSTEMDS-3249] CLA Decouple row Aggregate dict from column 
group
5303344 is described below

commit 53033442231ce747e7855471d1ae67439cafc8a4
Author: baunsgaard <baunsga...@tugraz.at>
AuthorDate: Mon Dec 13 20:40:29 2021 +0100

    [SYSTEMDS-3249] CLA Decouple row Aggregate dict from column group
    
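    This commit changes row aggregates in compressed space to first aggregate
    the dictionaries into vectors, and then populate the output from those
    pre-aggregated vectors. Because each dictionary entry is aggregated once
    and the result is reused for every row, this reduces the number of
    floating point operations needed.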
    
    On the worst-case dataset, Infini-MNIST, it halves the execution time from
    100 to 50 sec for 100 repetitions (uncompressed takes 8 sec). More
    importantly, it reduces the number of instructions by 10x. Hopefully,
    better caching and threading will reduce it further in a future
    commit.
    
    Closes #1483
---
 .../compress/colgroup/AColGroupCompressed.java     |  86 +++++----
 .../runtime/compress/colgroup/AColGroupValue.java  |  22 ++-
 .../runtime/compress/colgroup/ColGroupConst.java   |  40 +++--
 .../runtime/compress/colgroup/ColGroupDDC.java     |  18 +-
 .../runtime/compress/colgroup/ColGroupEmpty.java   |  31 +++-
 .../runtime/compress/colgroup/ColGroupOLE.java     |   9 +-
 .../runtime/compress/colgroup/ColGroupPFOR.java    |  43 +++--
 .../runtime/compress/colgroup/ColGroupRLE.java     | 198 ++++++++++-----------
 .../runtime/compress/colgroup/ColGroupSDC.java     |  64 +------
 .../compress/colgroup/ColGroupSDCSingle.java       |  16 +-
 .../compress/colgroup/ColGroupSDCSingleZeros.java  |  16 +-
 .../compress/colgroup/ColGroupSDCZeros.java        |  28 +--
 .../compress/colgroup/dictionary/ADictionary.java  |   8 +
 .../compress/colgroup/dictionary/Dictionary.java   |  20 +++
 .../colgroup/dictionary/MatrixBlockDictionary.java |  46 +++++
 .../compress/colgroup/dictionary/QDictionary.java  |   5 +
 .../sysds/runtime/compress/lib/CLALibCompAgg.java  |  72 +++++---
 17 files changed, 404 insertions(+), 318 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java
index 90cd5c9..1fbb843 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.java
@@ -57,22 +57,48 @@ public abstract class AColGroupCompressed extends AColGroup 
{
 
        protected abstract void computeSum(double[] c, int nRows);
 
-       protected abstract void computeRowSums(double[] c, int rl, int ru);
-
        protected abstract void computeSumSq(double[] c, int nRows);
 
-       protected abstract void computeRowSumsSq(double[] c, int rl, int ru);
-
        protected abstract void computeColSumsSq(double[] c, int nRows);
 
-       protected abstract void computeRowMxx(double[] c, Builtin builtin, int 
rl, int ru);
+       protected abstract void computeRowSums(double[] c, int rl, int ru, 
double[] preAgg);
+
+       protected abstract void computeRowMxx(double[] c, Builtin builtin, int 
rl, int ru, double[] preAgg);
 
        protected abstract void computeProduct(double[] c, int nRows);
 
-       protected abstract void computeRowProduct(double[] c, int rl, int ru);
+       protected abstract void computeRowProduct(double[] c, int rl, int ru, 
double[] preAgg);
 
        protected abstract void computeColProduct(double[] c, int nRows);
 
+       protected abstract double[] preAggSumRows();
+
+       protected abstract double[] preAggSumSqRows();
+
+       protected abstract double[] preAggProductRows();
+
+       protected abstract double[] preAggBuiltinRows(Builtin builtin);
+
+       public double[] preAggRows(AggregateUnaryOperator op) {
+               final ValueFunction fn = op.aggOp.increOp.fn;
+               if(fn instanceof KahanPlusSq)
+                       return preAggSumSqRows();
+               else if(fn instanceof Plus || fn instanceof KahanPlus)
+                       return preAggSumRows();
+               else if(fn instanceof Multiply)
+                       return preAggProductRows();
+               else if(fn instanceof Builtin) {
+                       Builtin bop = (Builtin) fn;
+                       BuiltinCode bopC = bop.getBuiltinCode();
+                       if(bopC == BuiltinCode.MAX || bopC == BuiltinCode.MIN)
+                               return preAggBuiltinRows(bop);
+                       else
+                               throw new DMLScriptException("unsupported 
builtin type: " + bop);
+               }
+               else
+                       throw new DMLScriptException("Unknown UnaryAggregate 
operator on CompressedMatrixBlock " + op);
+       }
+
        @Override
        public double getMin() {
                return computeMxx(Double.POSITIVE_INFINITY, 
Builtin.getBuiltinFnObject(BuiltinCode.MIN));
@@ -85,31 +111,33 @@ public abstract class AColGroupCompressed extends 
AColGroup {
 
        @Override
        public final void unaryAggregateOperations(AggregateUnaryOperator op, 
double[] c, int nRows, int rl, int ru) {
+               unaryAggregateOperations(op, c, nRows, rl, ru, null);
+       }
+
+       public final void unaryAggregateOperations(AggregateUnaryOperator op, 
double[] c, int nRows, int rl, int ru,
+               double[] preAgg) {
                final ValueFunction fn = op.aggOp.increOp.fn;
-               if(fn instanceof Plus || fn instanceof KahanPlus || fn 
instanceof KahanPlusSq) {
-                       boolean square = fn instanceof KahanPlusSq;
-                       if(square){
-                               if(op.indexFn instanceof ReduceAll)
-                                       computeSumSq(c, nRows);
-                               else if(op.indexFn instanceof ReduceCol)
-                                       computeRowSumsSq(c, rl, ru);
-                               else if(op.indexFn instanceof ReduceRow)
-                                       computeColSumsSq(c, nRows);
-                       }
-                       else{
-                               if(op.indexFn instanceof ReduceAll)
-                                       computeSum(c, nRows);
-                               else if(op.indexFn instanceof ReduceCol)
-                                       computeRowSums(c, rl, ru);
-                               else if(op.indexFn instanceof ReduceRow)
-                                       computeColSums(c, nRows);
-                       }
+               if(fn instanceof KahanPlusSq) {
+                       if(op.indexFn instanceof ReduceAll)
+                               computeSumSq(c, nRows);
+                       else if(op.indexFn instanceof ReduceCol)
+                               computeRowSums(c, rl, ru, preAgg);
+                       else if(op.indexFn instanceof ReduceRow)
+                               computeColSumsSq(c, nRows);
+               }
+               else if(fn instanceof Plus || fn instanceof KahanPlus) {
+                       if(op.indexFn instanceof ReduceAll)
+                               computeSum(c, nRows);
+                       else if(op.indexFn instanceof ReduceCol)
+                               computeRowSums(c, rl, ru, preAgg);
+                       else if(op.indexFn instanceof ReduceRow)
+                               computeColSums(c, nRows);
                }
                else if(fn instanceof Multiply) {
                        if(op.indexFn instanceof ReduceAll)
                                computeProduct(c, nRows);
                        else if(op.indexFn instanceof ReduceCol)
-                               computeRowProduct(c, rl, ru);
+                               computeRowProduct(c, rl, ru, preAgg);
                        else if(op.indexFn instanceof ReduceRow)
                                computeColProduct(c, nRows);
                }
@@ -120,17 +148,15 @@ public abstract class AColGroupCompressed extends 
AColGroup {
                                if(op.indexFn instanceof ReduceAll)
                                        c[0] = computeMxx(c[0], bop);
                                else if(op.indexFn instanceof ReduceCol)
-                                       computeRowMxx(c, bop, rl, ru);
+                                       computeRowMxx(c, bop, rl, ru, preAgg);
                                else if(op.indexFn instanceof ReduceRow)
                                        computeColMxx(c, bop);
                        }
-                       else {
+                       else
                                throw new DMLScriptException("unsupported 
builtin type: " + bop);
-                       }
                }
-               else {
+               else
                        throw new DMLScriptException("Unknown UnaryAggregate 
operator on CompressedMatrixBlock");
-               }
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
index b38bae1..4e44191 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java
@@ -350,11 +350,31 @@ public abstract class AColGroupValue extends 
AColGroupCompressed implements Clon
        }
 
        @Override
-       protected void computeRowProduct(double[] c, int rl, int ru) {
+       protected void computeRowProduct(double[] c, int rl, int ru, double[] 
preAgg) {
                throw new NotImplementedException();
        }
 
        @Override
+       protected double[] preAggSumRows(){
+               return _dict.sumAllRowsToDouble(_colIndexes.length);
+       }
+
+       @Override
+       protected double[] preAggSumSqRows(){
+               return _dict.sumAllRowsToDoubleSq(_colIndexes.length);
+       }
+
+       @Override
+       protected double[] preAggProductRows(){
+               throw new NotImplementedException();
+       }
+
+       @Override
+       protected double[] preAggBuiltinRows(Builtin builtin){
+               return _dict.aggregateRows(builtin, _colIndexes.length);
+       }
+
+       @Override
        protected void computeColProduct(double[] c, int nRows) {
                _dict.colProduct(c, getCounts(), _colIndexes);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
index fbc510c..45e4afc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java
@@ -75,10 +75,10 @@ public class ColGroupConst extends AColGroupCompressed {
        }
 
        @Override
-       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru) {
-               double value = _dict.aggregateRows(builtin, 
_colIndexes.length)[0];
+       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru, double[] preAgg) {
+               double v = preAgg[0];
                for(int i = rl; i < ru; i++)
-                       c[i] = builtin.execute(c[i], value);
+                       c[i] = builtin.execute(c[i], v);
        }
 
        @Override
@@ -177,15 +177,8 @@ public class ColGroupConst extends AColGroupCompressed {
        }
 
        @Override
-       protected void computeRowSums(double[] c, int rl, int ru) {
-               double vals = _dict.sumAllRowsToDouble(_colIndexes.length)[0];
-               for(int rix = rl; rix < ru; rix++)
-                       c[rix] += vals;
-       }
-
-       @Override
-       protected void computeRowSumsSq(double[] c, int rl, int ru) {
-               double vals = _dict.sumAllRowsToDoubleSq(_colIndexes.length)[0];
+       protected void computeRowSums(double[] c, int rl, int ru, double[] 
preAgg) {
+               double vals = preAgg[0];
                for(int rix = rl; rix < ru; rix++)
                        c[rix] += vals;
        }
@@ -323,7 +316,7 @@ public class ColGroupConst extends AColGroupCompressed {
        }
 
        @Override
-       protected void computeRowProduct(double[] c, int rl, int ru) {
+       protected void computeRowProduct(double[] c, int rl, int ru, double[] 
preAgg) {
                throw new NotImplementedException();
        }
 
@@ -332,4 +325,25 @@ public class ColGroupConst extends AColGroupCompressed {
                throw new NotImplementedException();
 
        }
+
+       @Override
+       protected double[] preAggSumRows() {
+               return _dict.sumAllRowsToDouble(_colIndexes.length);
+       }
+
+       @Override
+       protected double[] preAggSumSqRows() {
+               return _dict.sumAllRowsToDoubleSq(_colIndexes.length);
+
+       }
+
+       @Override
+       protected double[] preAggProductRows() {
+               throw new NotImplementedException();
+       }
+
+       @Override
+       protected double[] preAggBuiltinRows(Builtin builtin) {
+               return _dict.aggregateRows(builtin, _colIndexes.length);
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
index 1651a8b..c7e2a34 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java
@@ -105,25 +105,15 @@ public class ColGroupDDC extends APreAgg {
        }
 
        @Override
-       protected void computeRowSums(double[] c, int rl, int ru) {
-               double[] vals = _dict.sumAllRowsToDouble(_colIndexes.length);
+       protected void computeRowSums(double[] c, int rl, int ru, double[] 
preAgg) {
                for(int rix = rl; rix < ru; rix++)
-                       c[rix] += vals[_data.getIndex(rix)];
+                       c[rix] += preAgg[_data.getIndex(rix)];
        }
 
        @Override
-       protected void computeRowSumsSq(double[] c, int rl, int ru) {
-               double[] vals = _dict.sumAllRowsToDoubleSq(_colIndexes.length);
-               for(int rix = rl; rix < ru; rix++)
-                       c[rix] += vals[_data.getIndex(rix)];
-       }
-
-       @Override
-       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru) {
-               final int nCol = getNumCols();
-               double[] preAggregatedRows = _dict.aggregateRows(builtin, nCol);
+       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru, double[] preAgg) {
                for(int i = rl; i < ru; i++)
-                       c[i] = builtin.execute(c[i], 
preAggregatedRows[_data.getIndex(i)]);
+                       c[i] = builtin.execute(c[i], preAgg[_data.getIndex(i)]);
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java
index a75f046..bcd139a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java
@@ -206,7 +206,7 @@ public class ColGroupEmpty extends AColGroupCompressed {
        }
 
        @Override
-       protected void computeRowSums(double[] c, int rl, int ru) {
+       protected void computeRowSums(double[] c, int rl, int ru, double[] 
preAgg) {
                // do nothing
        }
 
@@ -221,17 +221,12 @@ public class ColGroupEmpty extends AColGroupCompressed {
        }
 
        @Override
-       protected void computeRowSumsSq(double[] c, int rl, int ru) {
-               // do nothing
-       }
-
-       @Override
        protected void computeColSumsSq(double[] c, int nRows) {
                // do nothing
        }
 
        @Override
-       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru) {
+       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru, double[] preAgg) {
                for(int r = rl; r < ru; r++)
                        c[r] = builtin.execute(c[r], 0);
        }
@@ -247,7 +242,7 @@ public class ColGroupEmpty extends AColGroupCompressed {
        }
 
        @Override
-       protected void computeRowProduct(double[] c, int rl, int ru) {
+       protected void computeRowProduct(double[] c, int rl, int ru, double[] 
preAgg) {
                // do nothing
        }
 
@@ -255,4 +250,24 @@ public class ColGroupEmpty extends AColGroupCompressed {
        protected void computeColProduct(double[] c, int nRows) {
                // do nothing
        }
+
+       @Override
+       protected double[] preAggSumRows() {
+               return null;
+       }
+
+       @Override
+       protected double[] preAggSumSqRows() {
+               return null;
+       }
+
+       @Override
+       protected double[] preAggProductRows() {
+               return null;
+       }
+
+       @Override
+       protected double[] preAggBuiltinRows(Builtin builtin) {
+               return null;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java
index 285c710..67b364a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java
@@ -217,7 +217,7 @@ public class ColGroupOLE extends AColGroupOffset {
        // }
 
        @Override
-       protected void computeRowSums(double[] c, int rl, int ru) {
+       protected void computeRowSums(double[] c, int rl, int ru, double[] 
preAgg) {
                throw new NotImplementedException();
                // final int blksz = CompressionSettings.BITMAP_BLOCK_SZ;
                // final int numVals = getNumValues();
@@ -283,12 +283,7 @@ public class ColGroupOLE extends AColGroupOffset {
        }
 
        @Override
-       protected void computeRowSumsSq(double[] c, int rl, int ru) {
-               throw new NotImplementedException();
-       }
-
-       @Override
-       protected final void computeRowMxx(double[] c, Builtin builtin, int rl, 
int ru) {
+       protected final void computeRowMxx(double[] c, Builtin builtin, int rl, 
int ru, double[] preAgg) {
                // NOTE: zeros handled once for all column groups outside
                final int blksz = CompressionSettings.BITMAP_BLOCK_SZ;
                final int numVals = getNumValues();
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupPFOR.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupPFOR.java
index 99fa68e..a39b17f 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupPFOR.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupPFOR.java
@@ -121,17 +121,6 @@ public class ColGroupPFOR extends AMorphingMMColGroup {
                return _data.getCounts(counts, _numRows);
        }
 
-       @Override
-       protected void computeRowSums(double[] c, int rl, int ru) {
-               // Add reference value sum.
-               final double refSum = refSum();
-               for(int rix = rl; rix < ru; rix++)
-                       c[rix] += refSum;
-
-               final double[] vals = 
_dict.sumAllRowsToDouble(_colIndexes.length);
-               ColGroupSDCZeros.computeRowSums(c, rl, ru, vals, _data, 
_indexes, _numRows);
-       }
-
        private final double refSum() {
                double ret = 0;
                for(double d : _reference)
@@ -140,15 +129,13 @@ public class ColGroupPFOR extends AMorphingMMColGroup {
        }
 
        @Override
-       protected void computeRowSumsSq(double[] c, int rl, int ru) {
-               final double[] vals = _dict.sumAllRowsToDoubleSq(_reference);
-               ColGroupSDC.computeRowSumsSq(c, rl, ru, vals, _data, _indexes, 
_numRows);
+       protected void computeRowSums(double[] c, int rl, int ru, double[] 
preAgg) {
+               ColGroupSDC.computeRowSums(c, rl, ru, preAgg, _data, _indexes, 
_numRows);
        }
 
        @Override
-       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru) {
-               final double[] vals = _dict.aggregateRows(builtin, _reference);
-               ColGroupSDC.computeRowMxx(c, builtin, rl, ru, vals, _data, 
_indexes, _numRows, vals[vals.length - 1]);
+       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru, double[] preAgg) {
+               ColGroupSDC.computeRowMxx(c, builtin, rl, ru, preAgg, _data, 
_indexes, _numRows, preAgg[preAgg.length - 1]);
        }
 
        @Override
@@ -310,12 +297,32 @@ public class ColGroupPFOR extends AMorphingMMColGroup {
        }
 
        @Override
+       protected double[] preAggSumRows() {
+               return _dict.sumAllRowsToDouble(_reference);
+       }
+
+       @Override
+       protected double[] preAggSumSqRows() {
+               return _dict.sumAllRowsToDoubleSq(_reference);
+       }
+
+       @Override
+       protected double[] preAggProductRows() {
+               throw new NotImplementedException();
+       }
+
+       @Override
+       protected double[] preAggBuiltinRows(Builtin builtin) {
+               return _dict.aggregateRows(builtin, _reference);
+       }
+
+       @Override
        protected void computeProduct(double[] c, int nRows) {
                throw new NotImplementedException("Not Implemented PFOR");
        }
 
        @Override
-       protected void computeRowProduct(double[] c, int rl, int ru) {
+       protected void computeRowProduct(double[] c, int rl, int ru, double[] 
preAgg) {
                throw new NotImplementedException("Not Implemented PFOR");
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java
index 4a24a07..ed28968 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java
@@ -29,7 +29,6 @@ import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.Pair;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
 
@@ -217,7 +216,7 @@ public class ColGroupRLE extends AColGroupOffset {
        // }
 
        @Override
-       protected void computeRowSums(double[] c, int rl, int ru) {
+       protected void computeRowSums(double[] c, int rl, int ru, double[] 
preAgg) {
                throw new NotImplementedException();
                // final int numVals = getNumValues();
 
@@ -286,60 +285,89 @@ public class ColGroupRLE extends AColGroupOffset {
                // }
        }
 
+       // @Override
+       // protected void computeRowSumsSq(double[] c, int rl, int ru, double[] 
preAgg) {
+       // throw new NotImplementedException();
+       // // final int numVals = getNumValues();
+
+       // // if(numVals > 1 && _numRows > CompressionSettings.BITMAP_BLOCK_SZ) 
{
+       // // final int blksz = CompressionSettings.BITMAP_BLOCK_SZ;
+
+       // // // step 1: prepare position and value arrays
+
+       // // // current pos / values per RLE list
+       // // int[] astart = new int[numVals];
+       // // int[] apos = skipScan(numVals, rl, astart);
+       // // double[] aval = _dict.sumAllRowsToDouble(square, 
_colIndexes.length);
+
+       // // // step 2: cache conscious matrix-vector via horizontal scans
+       // // for(int bi = rl; bi < ru; bi += blksz) {
+       // // int bimax = Math.min(bi + blksz, ru);
+
+       // // // horizontal segment scan, incl pos maintenance
+       // // for(int k = 0; k < numVals; k++) {
+       // // int boff = _ptr[k];
+       // // int blen = len(k);
+       // // double val = aval[k];
+       // // int bix = apos[k];
+       // // int start = astart[k];
+
+       // // // compute partial results, not aligned
+       // // while(bix < blen) {
+       // // int lstart = _data[boff + bix];
+       // // int llen = _data[boff + bix + 1];
+       // // int from = Math.max(bi, start + lstart);
+       // // int to = Math.min(start + lstart + llen, bimax);
+       // // for(int rix = from; rix < to; rix++)
+       // // c[rix] += val;
+
+       // // if(start + lstart + llen >= bimax)
+       // // break;
+       // // start += lstart + llen;
+       // // bix += 2;
+       // // }
+
+       // // apos[k] = bix;
+       // // astart[k] = start;
+       // // }
+       // // }
+       // // }
+       // // else {
+       // // for(int k = 0; k < numVals; k++) {
+       // // int boff = _ptr[k];
+       // // int blen = len(k);
+       // // double val = _dict.sumRow(k, square, _colIndexes.length);
+
+       // // if(val != 0.0) {
+       // // Pair<Integer, Integer> tmp = skipScanVal(k, rl);
+       // // int bix = tmp.getKey();
+       // // int curRunStartOff = tmp.getValue();
+       // // int curRunEnd = tmp.getValue();
+       // // for(; bix < blen && curRunEnd < ru; bix += 2) {
+       // // curRunStartOff = curRunEnd + _data[boff + bix];
+       // // curRunEnd = curRunStartOff + _data[boff + bix + 1];
+       // // for(int rix = curRunStartOff; rix < curRunEnd && rix < ru; rix++)
+       // // c[rix] += val;
+
+       // // }
+       // // }
+       // // }
+       // // }
+       // }
+
        @Override
-       protected void computeRowSumsSq(double[] c, int rl, int ru) {
+       protected final void computeRowMxx(double[] c, Builtin builtin, int rl, 
int ru, double[] preAgg) {
                throw new NotImplementedException();
+               // NOTE: zeros handled once for all column groups outside
                // final int numVals = getNumValues();
+               // // double[] c = result.getDenseBlockValues();
+               // final double[] values = _dict.getValues();
 
-               // if(numVals > 1 && _numRows > 
CompressionSettings.BITMAP_BLOCK_SZ) {
-               // final int blksz = CompressionSettings.BITMAP_BLOCK_SZ;
-
-               // // step 1: prepare position and value arrays
-
-               // // current pos / values per RLE list
-               // int[] astart = new int[numVals];
-               // int[] apos = skipScan(numVals, rl, astart);
-               // double[] aval = _dict.sumAllRowsToDouble(square, 
_colIndexes.length);
-
-               // // step 2: cache conscious matrix-vector via horizontal scans
-               // for(int bi = rl; bi < ru; bi += blksz) {
-               // int bimax = Math.min(bi + blksz, ru);
-
-               // // horizontal segment scan, incl pos maintenance
-               // for(int k = 0; k < numVals; k++) {
-               // int boff = _ptr[k];
-               // int blen = len(k);
-               // double val = aval[k];
-               // int bix = apos[k];
-               // int start = astart[k];
-
-               // // compute partial results, not aligned
-               // while(bix < blen) {
-               // int lstart = _data[boff + bix];
-               // int llen = _data[boff + bix + 1];
-               // int from = Math.max(bi, start + lstart);
-               // int to = Math.min(start + lstart + llen, bimax);
-               // for(int rix = from; rix < to; rix++)
-               // c[rix] += val;
-
-               // if(start + lstart + llen >= bimax)
-               // break;
-               // start += lstart + llen;
-               // bix += 2;
-               // }
-
-               // apos[k] = bix;
-               // astart[k] = start;
-               // }
-               // }
-               // }
-               // else {
                // for(int k = 0; k < numVals; k++) {
                // int boff = _ptr[k];
                // int blen = len(k);
-               // double val = _dict.sumRow(k, square, _colIndexes.length);
+               // double val = mxxValues(k, builtin, values);
 
-               // if(val != 0.0) {
                // Pair<Integer, Integer> tmp = skipScanVal(k, rl);
                // int bix = tmp.getKey();
                // int curRunStartOff = tmp.getValue();
@@ -348,40 +376,12 @@ public class ColGroupRLE extends AColGroupOffset {
                // curRunStartOff = curRunEnd + _data[boff + bix];
                // curRunEnd = curRunStartOff + _data[boff + bix + 1];
                // for(int rix = curRunStartOff; rix < curRunEnd && rix < ru; 
rix++)
-               // c[rix] += val;
-
-               // }
-               // }
+               // c[rix] = builtin.execute(c[rix], val);
                // }
                // }
        }
 
        @Override
-       protected final void computeRowMxx(double[] c, Builtin builtin, int rl, 
int ru) {
-               // NOTE: zeros handled once for all column groups outside
-               final int numVals = getNumValues();
-               // double[] c = result.getDenseBlockValues();
-               final double[] values = _dict.getValues();
-
-               for(int k = 0; k < numVals; k++) {
-                       int boff = _ptr[k];
-                       int blen = len(k);
-                       double val = mxxValues(k, builtin, values);
-
-                       Pair<Integer, Integer> tmp = skipScanVal(k, rl);
-                       int bix = tmp.getKey();
-                       int curRunStartOff = tmp.getValue();
-                       int curRunEnd = tmp.getValue();
-                       for(; bix < blen && curRunEnd < ru; bix += 2) {
-                               curRunStartOff = curRunEnd + _data[boff + bix];
-                               curRunEnd = curRunStartOff + _data[boff + bix + 
1];
-                               for(int rix = curRunStartOff; rix < curRunEnd 
&& rix < ru; rix++)
-                                       c[rix] = builtin.execute(c[rix], val);
-                       }
-               }
-       }
-
-       @Override
        public boolean[] computeZeroIndicatorVector() {
                boolean[] ret = new boolean[_numRows];
                final int numVals = getNumValues();
@@ -489,28 +489,28 @@ public class ColGroupRLE extends AColGroupOffset {
                return apos;
        }
 
-       private Pair<Integer, Integer> skipScanVal(int k, int rl) {
-               int apos = 0;
-               int astart = 0;
-
-               if(rl > 0) { // rl aligned with blksz
-                       int boff = _ptr[k];
-                       int blen = len(k);
-                       int bix = 0;
-                       int start = 0;
-                       while(bix < blen) {
-                               int lstart = _data[boff + bix]; // start
-                               int llen = _data[boff + bix + 1]; // len
-                               if(start + lstart + llen >= rl)
-                                       break;
-                               start += lstart + llen;
-                               bix += 2;
-                       }
-                       apos = bix;
-                       astart = start;
-               }
-               return new Pair<>(apos, astart);
-       }
+       // private Pair<Integer, Integer> skipScanVal(int k, int rl) {
+       // int apos = 0;
+       // int astart = 0;
+
+       // if(rl > 0) { // rl aligned with blksz
+       // int boff = _ptr[k];
+       // int blen = len(k);
+       // int bix = 0;
+       // int start = 0;
+       // while(bix < blen) {
+       // int lstart = _data[boff + bix]; // start
+       // int llen = _data[boff + bix + 1]; // len
+       // if(start + lstart + llen >= rl)
+       // break;
+       // start += lstart + llen;
+       // bix += 2;
+       // }
+       // apos = bix;
+       // astart = start;
+       // }
+       // return new Pair<>(apos, astart);
+       // }
 
        @Override
        public void leftMultByMatrix(MatrixBlock matrix, MatrixBlock result, 
int rl, int ru) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index fd94d0a..759252a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -43,7 +43,7 @@ import 
org.apache.sysds.runtime.matrix.operators.ScalarOperator;
  */
 public class ColGroupSDC extends AMorphingMMColGroup {
        private static final long serialVersionUID = 769993538831949086L;
-       
+
        /** Sparse row indexes for the data */
        protected transient AOffset _indexes;
        /** Pointers to row indexes in the dictionary. Note the dictionary has 
one extra entry. */
@@ -93,67 +93,22 @@ public class ColGroupSDC extends AMorphingMMColGroup {
        }
 
        @Override
-       protected void computeRowSums(double[] c, int rl, int ru) {
-
-               final AIterator it = _indexes.getIterator(rl);
-               final int numVals = getNumValues();
-               int r = rl;
-               final double[] vals = 
_dict.sumAllRowsToDouble(_colIndexes.length);
-               final double def = vals[numVals - 1];
-               if(it != null && it.value() > ru)
-                       _indexes.cacheIterator(it, ru);
-               else if(it != null && ru >= _indexes.getOffsetToLast()) {
-                       final int maxId = _data.size() - 1;
-                       while(true) {
-                               if(it.value() == r) {
-                                       c[r] += 
vals[_data.getIndex(it.getDataIndex())];
-                                       if(it.getDataIndex() < maxId)
-                                               it.next();
-                                       else {
-                                               r++;
-                                               break;
-                                       }
-                               }
-                               else
-                                       c[r] += def;
-                               r++;
-                       }
-               }
-               else if(it != null) {
-                       while(it.isNotOver(ru)) {
-                               if(it.value() == r)
-                                       c[r] += 
vals[_data.getIndex(it.getDataIndexAndIncrement())];
-                               else
-                                       c[r] += def;
-                               r++;
-                       }
-                       _indexes.cacheIterator(it, ru);
-               }
-
-               while(r < ru) {
-                       c[r] += def;
-                       r++;
-               }
-       }
-
-       @Override
-       protected void computeRowSumsSq(double[] c, int rl, int ru) {
-               final double[] vals = 
_dict.sumAllRowsToDoubleSq(_colIndexes.length);
-               computeRowSumsSq(c, rl, ru, vals, _data, _indexes, _numRows);
+       protected void computeRowSums(double[] c, int rl, int ru, double[] 
preAgg) {
+               computeRowSums(c, rl, ru, preAgg, _data, _indexes, _numRows);
        }
 
-       protected static final void computeRowSumsSq(double[] c, int rl, int 
ru, double[] vals, AMapToData data,
+       protected static final void computeRowSums(double[] c, int rl, int ru, 
double[] preAgg, AMapToData data,
                AOffset indexes, int nRows) {
                int r = rl;
                final AIterator it = indexes.getIterator(rl);
-               final double def = vals[vals.length - 1];
+               final double def = preAgg[preAgg.length - 1];
                if(it != null && it.value() > ru)
                        indexes.cacheIterator(it, ru);
                else if(it != null && ru >= indexes.getOffsetToLast()) {
                        final int maxId = data.size() - 1;
                        while(true) {
                                if(it.value() == r) {
-                                       c[r] += 
vals[data.getIndex(it.getDataIndex())];
+                                       c[r] += 
preAgg[data.getIndex(it.getDataIndex())];
                                        if(it.getDataIndex() < maxId)
                                                it.next();
                                        else {
@@ -169,7 +124,7 @@ public class ColGroupSDC extends AMorphingMMColGroup {
                else if(it != null) {
                        while(r < ru) {
                                if(it.value() == r)
-                                       c[r] += 
vals[data.getIndex(it.getDataIndexAndIncrement())];
+                                       c[r] += 
preAgg[data.getIndex(it.getDataIndexAndIncrement())];
                                else
                                        c[r] += def;
                                r++;
@@ -184,9 +139,8 @@ public class ColGroupSDC extends AMorphingMMColGroup {
        }
 
        @Override
-       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru) {
-               final double[] vals = _dict.aggregateRows(builtin, 
_colIndexes.length);
-               computeRowMxx(c, builtin, rl, ru, vals, _data, _indexes, 
_numRows, vals[vals.length - 1]);
+       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru, double[] preAgg) {
+               computeRowMxx(c, builtin, rl, ru, preAgg, _data, _indexes, 
_numRows, preAgg[preAgg.length - 1]);
        }
 
        protected static final void computeRowMxx(double[] c, Builtin builtin, 
int rl, int ru, double[] vals,
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
index a41198d..a00841d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
@@ -78,17 +78,6 @@ public class ColGroupSDCSingle extends AMorphingMMColGroup {
        }
 
        @Override
-       protected void computeRowSums(double[] c, int rl, int ru) {
-               final double[] vals = 
_dict.sumAllRowsToDouble(_colIndexes.length);
-               computeRowSums(c, rl, ru, vals);
-       }
-
-       @Override
-       protected void computeRowSumsSq(double[] c, int rl, int ru) {
-               final double[] vals = 
_dict.sumAllRowsToDoubleSq(_colIndexes.length);
-               computeRowSums(c, rl, ru, vals);
-       }
-
        protected void computeRowSums(double[] c, int rl, int ru, double[] 
vals) {
                int r = rl;
                final AIterator it = _indexes.getIterator(rl);
@@ -131,9 +120,8 @@ public class ColGroupSDCSingle extends AMorphingMMColGroup {
        }
 
        @Override
-       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru) {
-               final double[] vals = _dict.aggregateRows(builtin, 
_colIndexes.length);
-               computeRowMxx(c, builtin, rl, ru, _indexes, _numRows, vals[1], 
vals[0]);
+       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru, double[] preAgg) {
+               computeRowMxx(c, builtin, rl, ru, _indexes, _numRows, 
preAgg[1], preAgg[0]);
        }
 
        protected static final void computeRowMxx(double[] c, Builtin builtin, 
int rl, int ru, AOffset indexes, int nRows,
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
index ca2415c..dd419e5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java
@@ -213,15 +213,8 @@ public class ColGroupSDCSingleZeros extends APreAgg {
        }
 
        @Override
-       protected void computeRowSums(double[] c, int rl, int ru) {
-               final double def = 
_dict.sumAllRowsToDouble(_colIndexes.length)[0];
-               computeRowSum(c, rl, ru, def);
-       }
-
-       @Override
-       protected void computeRowSumsSq(double[] c, int rl, int ru) {
-               final double def = 
_dict.sumAllRowsToDoubleSq(_colIndexes.length)[0];
-               computeRowSum(c, rl, ru, def);
+       protected void computeRowSums(double[] c, int rl, int ru, double[] 
preAgg) {
+               computeRowSum(c, rl, ru, preAgg[0]);
        }
 
        protected void computeRowSum(double[] c, int rl, int ru, double def) {
@@ -249,9 +242,8 @@ public class ColGroupSDCSingleZeros extends APreAgg {
        }
 
        @Override
-       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru) {
-               final double[] vals = _dict.aggregateRows(builtin, 
_colIndexes.length);
-               ColGroupSDCSingle.computeRowMxx(c, builtin, rl, ru, _indexes, 
_numRows, 0, vals[0]);
+       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru, double[] preAgg) {
+               ColGroupSDCSingle.computeRowMxx(c, builtin, rl, ru, _indexes, 
_numRows, 0, preAgg[0]);
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
index ee3bad4..7e39b74 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
@@ -295,22 +295,11 @@ public class ColGroupSDCZeros extends APreAgg {
        }
 
        @Override
-       protected void computeRowSums(double[] c, int rl, int ru) {
-               final double[] vals = 
_dict.sumAllRowsToDouble(_colIndexes.length);
-               computeRowSums(c, rl, ru, vals);
+       protected void computeRowSums(double[] c, int rl, int ru, double[] 
preAgg) {
+               computeRowSums(c, rl, ru, preAgg, _data, _indexes, _numRows);
        }
 
-       @Override
-       protected void computeRowSumsSq(double[] c, int rl, int ru) {
-               final double[] vals = 
_dict.sumAllRowsToDoubleSq(_colIndexes.length);
-               computeRowSums(c, rl, ru, vals);
-       }
-
-       protected void computeRowSums(double[] c, int rl, int ru, double[] 
vals) {
-               computeRowSums(c, rl, ru, vals, _data, _indexes, _numRows);
-       }
-
-       protected static final void computeRowSums(double[] c, int rl, int ru, 
double[] vals, AMapToData data,
+       protected static final void computeRowSums(double[] c, int rl, int ru, 
double[] preAgg, AMapToData data,
                AOffset indexes, int nRows) {
                final AIterator it = indexes.getIterator(rl);
                if(it == null)
@@ -319,15 +308,15 @@ public class ColGroupSDCZeros extends APreAgg {
                        indexes.cacheIterator(it, ru);
                else if(ru >= indexes.getOffsetToLast()) {
                        final int maxId = data.size() - 1;
-                       c[it.value()] += vals[data.getIndex(it.getDataIndex())];
+                       c[it.value()] += 
preAgg[data.getIndex(it.getDataIndex())];
                        while(it.getDataIndex() < maxId) {
                                it.next();
-                               c[it.value()] += 
vals[data.getIndex(it.getDataIndex())];
+                               c[it.value()] += 
preAgg[data.getIndex(it.getDataIndex())];
                        }
                }
                else {
                        while(it.isNotOver(ru)) {
-                               c[it.value()] += 
vals[data.getIndex(it.getDataIndex())];
+                               c[it.value()] += 
preAgg[data.getIndex(it.getDataIndex())];
                                it.next();
                        }
                        indexes.cacheIterator(it, ru);
@@ -335,9 +324,8 @@ public class ColGroupSDCZeros extends APreAgg {
        }
 
        @Override
-       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru) {
-               final double[] vals = _dict.aggregateRows(builtin, 
_colIndexes.length);
-               ColGroupSDC.computeRowMxx(c, builtin, rl, ru, vals, _data, 
_indexes, _numRows, 0);
+       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru, double[] preAgg) {
+               ColGroupSDC.computeRowMxx(c, builtin, rl, ru, preAgg, _data, 
_indexes, _numRows, 0);
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index 7ee7ed3..cca2a89 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -289,6 +289,14 @@ public abstract class ADictionary implements Serializable {
        public abstract double[] sumAllRowsToDouble(int nrColumns);
 
        /**
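+        * Pre-aggregate each tuple in the dictionary into a single double (its row sum), adding the reference value of each column to every cell.
+        * 
+        * @param reference The per-column reference values added to each cell; its length is taken as the number of columns.
+        * @return The row sum (reference included) of each tuple, followed by one extra trailing entry holding the sum of the reference values alone.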
+        */
+       public abstract double[] sumAllRowsToDouble(double[] reference);
+
+       /**
         * Method used as a pre-aggregate of each tuple in the dictionary, to 
single double values.
         * 
         * Note if the number of columns is one the actual dictionaries values 
are simply returned.
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
index f378756..c3faa6c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
@@ -306,6 +306,18 @@ public class Dictionary extends ADictionary {
        }
 
        @Override
+       public double[] sumAllRowsToDouble(double[] reference){
+               final int nCol = reference.length;
+               final int numVals = getNumberOfValues(nCol);
+               double[] ret = new double[numVals + 1];
+               for(int k = 0; k < numVals; k++)
+                       ret[k] = sumRow(k, nCol, reference);
+               for(int i = 0; i < nCol; i++)
+                       ret[numVals] += reference[i];
+               return ret;
+       }
+
+       @Override
        public double[] sumAllRowsToDoubleSq(int nrColumns) {
                // pre-aggregate value tuple
                final int numVals = getNumberOfValues(nrColumns);
@@ -337,6 +349,14 @@ public class Dictionary extends ADictionary {
                return res;
        }
 
+       public double sumRow(int k, int nrColumns, double[] reference) {
+               final int valOff = k * nrColumns;
+               double res = 0.0;
+               for(int i = 0; i < nrColumns; i++)
+                       res += _values[valOff + i] + reference[i];
+               return res;
+       }
+
        @Override
        public double sumRowSq(int k, int nrColumns) {
                final int valOff = k * nrColumns;
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
index b5c826a..b69bcf2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
@@ -471,6 +471,52 @@ public class MatrixBlockDictionary extends ADictionary {
        }
 
        @Override
+       public double[] sumAllRowsToDouble(double[] reference){
+               final int nCol = reference.length;
+               final int numVals = _data.getNumRows();
+               final double[] ret = new double[numVals + 1];
+
+               final int finalIndex = numVals;
+               for(int i = 0; i < nCol; i++)
+                       ret[finalIndex] += reference[i];
+
+               if(!_data.isEmpty() && _data.isInSparseFormat()) {
+                       final SparseBlock sb = _data.getSparseBlock();
+                       for(int i = 0; i < numVals; i++) {
+                               if(sb.isEmpty(i))
+                                       ret[i] = ret[finalIndex];
+                               else {
+                                       final int apos = sb.pos(i);
+                                       final int alen = sb.size(i) + apos;
+                                       final int[] aix = sb.indexes(i);
+                                       final double[] avals = sb.values(i);
+                                       int k = apos;
+                                       int j = 0;
+                                       for(; j < _data.getNumColumns() && k < 
alen; j++) {
+                                               final double v = aix[k] == j ? 
avals[k++] + reference[j] : reference[j];
+                                               ret[i] += v ;
+                                       }
+                                       for(; j < _data.getNumColumns(); j++)
+                                               ret[i] += reference[j];
+                               }
+
+                       }
+               }
+               else if(!_data.isEmpty()) {
+                       double[] values = _data.getDenseBlockValues();
+                       int off = 0;
+                       for(int k = 0; k < numVals; k++) {
+                               for(int j = 0; j < _data.getNumColumns(); j++) {
+                                       final double v = values[off++] + 
reference[j];
+                                       ret[k] += v ;
+                               }
+                       }
+               }
+
+               return ret;
+       }
+
+       @Override
        public double[] sumAllRowsToDoubleSq(int nrColumns) {
                final double[] ret = new double[_data.getNumRows()];
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
index 879892a..32ed014 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
@@ -525,4 +525,9 @@ public class QDictionary extends ADictionary {
                double[] newReference) {
                throw new NotImplementedException();
        }
+
+       @Override
+       public double[] sumAllRowsToDouble(double[] reference) {
+               throw new NotImplementedException();
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java
index 373c8af..9e01b84 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java
@@ -34,7 +34,7 @@ import 
org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.CompressionSettings;
 import org.apache.sysds.runtime.compress.DMLCompressionException;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup;
-import org.apache.sysds.runtime.compress.colgroup.AColGroupOffset;
+import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.functionobjects.Builtin;
@@ -111,9 +111,9 @@ public class CLALibCompAgg {
 
                        fillStart(inputMatrix, result, opm);
                        if(requireDecompress)
-                               aggregateUnaryOverlapping(inputMatrix, result, 
opm, indexesIn, inCP);
+                               aggOverlapping(inputMatrix, result, opm, 
indexesIn, inCP);
                        else
-                               
aggregateUnaryNormalCompressedMatrixBlock(inputMatrix, result, opm, blen, 
indexesIn, inCP);
+                               agg(inputMatrix, result, opm, blen, indexesIn, 
inCP);
                }
 
                result.recomputeNonZeros();
@@ -206,8 +206,8 @@ public class CLALibCompAgg {
                return op;
        }
 
-       private static void 
aggregateUnaryNormalCompressedMatrixBlock(CompressedMatrixBlock m, MatrixBlock 
o,
-               AggregateUnaryOperator op, int blen, MatrixIndexes indexesIn, 
boolean inCP) {
+       private static void agg(CompressedMatrixBlock m, MatrixBlock o, 
AggregateUnaryOperator op, int blen,
+               MatrixIndexes indexesIn, boolean inCP) {
 
                int k = op.getNumThreads();
                // replace mean operation with plus.
@@ -218,7 +218,11 @@ public class CLALibCompAgg {
                        aggregateInParallel(m, o, opm, k);
                else {
                        final int nRows = m.getNumRows();
-                       aggregateUnaryOperations(opm, m.getColGroups(), 
o.getDenseBlockValues(), nRows, 0, nRows, m.getNumColumns());
+                       if(op.indexFn instanceof ReduceCol)
+                               agg(opm, m.getColGroups(), 
o.getDenseBlockValues(), nRows, 0, nRows, m.getNumColumns(),
+                                       getPreAgg(opm, m.getColGroups()));
+                       else
+                               agg(opm, m.getColGroups(), 
o.getDenseBlockValues(), nRows, 0, nRows, m.getNumColumns(), null);
                }
 
                postProcessAggregate(m, o, op);
@@ -244,12 +248,13 @@ public class CLALibCompAgg {
                        if(op.indexFn instanceof ReduceCol) {
                                final int blkz = 
CompressionSettings.BITMAP_BLOCK_SZ;
                                final int blklen = Math.max((int) 
Math.ceil((double) r / (k * 2)), blkz);
+                               double[][] preAgg = getPreAgg(op, colGroups);
                                for(int i = 0; i < r; i += blklen)
-                                       tasks.add(new 
UnaryAggregateTask(colGroups, ret, r, i, Math.min(i + blklen, r), op, c, 
false));
+                                       tasks.add(new 
UnaryAggregateTask(colGroups, ret, r, i, Math.min(i + blklen, r), op, c, false, 
preAgg));
                        }
                        else
                                for(List<AColGroup> grp : 
createTaskPartition(colGroups, k))
-                                       tasks.add(new UnaryAggregateTask(grp, 
ret, r, 0, r, op, c, m1.isOverlapping()));
+                                       tasks.add(new UnaryAggregateTask(grp, 
ret, r, 0, r, op, c, m1.isOverlapping(), null));
 
                        List<Future<MatrixBlock>> futures = 
pool.invokeAll(tasks);
                        pool.shutdown();
@@ -275,6 +280,17 @@ public class CLALibCompAgg {
                }
        }
 
+       private static double[][] getPreAgg(AggregateUnaryOperator opm, 
List<AColGroup> groups) {
+               double[][] ret = new double[groups.size()][];
+               for(int i = 0; i < groups.size(); i++) {
+                       AColGroup g = groups.get(i);
+                       if(g instanceof AColGroupCompressed) {
+                               ret[i] = ((AColGroupCompressed) 
g).preAggRows(opm);
+                       }
+               }
+               return ret;
+       }
+
        private static void sumResults(MatrixBlock ret, 
List<Future<MatrixBlock>> futures)
                throws InterruptedException, ExecutionException {
                double val = ret.quickGetValue(0, 0);
@@ -353,7 +369,7 @@ public class CLALibCompAgg {
 
        }
 
-       private static void aggregateUnaryOverlapping(CompressedMatrixBlock m1, 
MatrixBlock ret, AggregateUnaryOperator op,
+       private static void aggOverlapping(CompressedMatrixBlock m1, 
MatrixBlock ret, AggregateUnaryOperator op,
                MatrixIndexes indexesIn, boolean inCP) {
                try {
                        List<Future<MatrixBlock>> rtasks = 
generateUnaryAggregateOverlappingFutures(m1, ret, op);
@@ -448,29 +464,30 @@ public class CLALibCompAgg {
                return grpParts;
        }
 
-       private static void aggregateUnaryOperations(AggregateUnaryOperator op, 
List<AColGroup> groups, double[] ret,
-               int nRows, int rl, int ru, int numColumns) {
-               if(op.indexFn instanceof ReduceCol && op.aggOp.increOp.fn 
instanceof Builtin)
-                       aggregateUnaryBuiltinRowOperation(op, groups, ret, 
nRows, rl, ru, numColumns);
+       private static void agg(AggregateUnaryOperator op, List<AColGroup> 
groups, double[] ret, int nRows, int rl, int ru,
+               int numColumns, double[][] preAgg) {
+               if(op.indexFn instanceof ReduceCol)
+                       aggRow(op, groups, ret, nRows, rl, ru, numColumns, 
preAgg);
                else
-                       aggregateUnaryNormalOperation(op, groups, ret, nRows, 
rl, ru, numColumns);
+                       aggColOrAll(op, groups, ret, nRows, rl, ru, numColumns);
        }
 
-       private static void 
aggregateUnaryNormalOperation(AggregateUnaryOperator op, List<AColGroup> 
groups, double[] ret,
-               int nRows, int rl, int ru, int numColumns) {
+       private static void aggColOrAll(AggregateUnaryOperator op, 
List<AColGroup> groups, double[] ret, int nRows, int rl,
+               int ru, int numColumns) {
                for(AColGroup grp : groups)
                        grp.unaryAggregateOperations(op, ret, nRows, rl, ru);
        }
 
-       private static void 
aggregateUnaryBuiltinRowOperation(AggregateUnaryOperator op, List<AColGroup> 
groups,
-               double[] ret, int nRows, int rl, int ru, int numColumns) {
-
-               for(AColGroup g : groups)
-                       if(g instanceof AColGroupOffset)
-                               throw new NotImplementedException("not 
implemented handling of offset colGroups for row aggregates");
+       private static void aggRow(AggregateUnaryOperator op, List<AColGroup> 
groups, double[] ret, int nRows, int rl,
+               int ru, int numColumns, double[][] preAgg) {
 
-               for(AColGroup grp : groups)
-                       grp.unaryAggregateOperations(op, ret, nRows, rl, ru);
+               for(int i = 0; i < groups.size(); i++) {
+                       AColGroup grp = groups.get(i);
+                       if(grp instanceof AColGroupCompressed)
+                               ((AColGroupCompressed) 
grp).unaryAggregateOperations(op, ret, nRows, rl, ru, preAgg[i]);
+                       else
+                               grp.unaryAggregateOperations(op, ret, nRows, 
rl, ru);
+               }
 
        }
 
@@ -520,16 +537,17 @@ public class CLALibCompAgg {
                private final MatrixBlock _ret;
                private final int _numColumns;
                private final AggregateUnaryOperator _op;
+               private final double[][] _preAgg;
 
                protected UnaryAggregateTask(List<AColGroup> groups, 
MatrixBlock ret, int nRows, int rl, int ru,
-                       AggregateUnaryOperator op, int numColumns, boolean 
overlapping) {
+                       AggregateUnaryOperator op, int numColumns, boolean 
overlapping, double[][] preAgg) {
                        _groups = groups;
                        _op = op;
                        _nRows = nRows;
                        _rl = rl;
                        _ru = ru;
                        _numColumns = numColumns;
-
+                       _preAgg = preAgg;
                        if(_op.indexFn instanceof ReduceAll || (_op.indexFn 
instanceof ReduceRow && overlapping))
                                _ret = genTmpReduceAllOrRow(ret, op);
                        else // colSums / rowSums not overlapping
@@ -539,7 +557,7 @@ public class CLALibCompAgg {
 
                @Override
                public MatrixBlock call() {
-                       aggregateUnaryOperations(_op, _groups, 
_ret.getDenseBlockValues(), _nRows, _rl, _ru, _numColumns);
+                       agg(_op, _groups, _ret.getDenseBlockValues(), _nRows, 
_rl, _ru, _numColumns, _preAgg);
                        return _ret;
                }
        }

Reply via email to