This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 217e31d  [SYSTEMDS-3105] CLA Left MM shared common element sum
217e31d is described below

commit 217e31d87b2ffc745c8a8e1d56cee80e7013aa35
Author: baunsgaard <[email protected]>
AuthorDate: Thu Aug 26 14:20:12 2021 +0200

    [SYSTEMDS-3105] CLA Left MM shared common element sum
    
    This commit modifies the left matrix multiplication (LMM) to share a
    common element sum across all SDC column groups, allowing them to
    skip all of their common elements.
    
    This improves the performance of LMM on InfiniMnist by 5-10x, depending
    on the number of rows in the uncompressed left-hand matrix.
    
    Closes #1376
---
 .../runtime/compress/colgroup/ColGroupFactory.java |  39 +++--
 .../runtime/compress/colgroup/ColGroupSDC.java     |   4 +-
 .../runtime/compress/colgroup/ColGroupValue.java   |   1 +
 .../runtime/compress/colgroup/offset/AOffset.java  |   7 +-
 .../runtime/compress/lib/CLALibLeftMultBy.java     | 163 +++++++++++++++------
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  14 +-
 6 files changed, 163 insertions(+), 65 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 1d7b97d..f42e382 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -45,6 +45,7 @@ import 
org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
 import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
 import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
 import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
+import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory.CostType;
 import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
 import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorExact;
 import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
@@ -214,34 +215,40 @@ public final class ColGroupFactory {
                CompressedSizeInfoColGroup cg, int[] colIndexes) {
                try {
                        final int nrUniqueEstimate = cg.getNumVals();
-                       final CompressionType estimatedBestCompressionType = 
cg.getBestCompressionType();
+                       CompressionType estimatedBestCompressionType = 
cg.getBestCompressionType();
+                       
+                       if(estimatedBestCompressionType == CompressionType.SDC 
&& cs.costComputationType == CostType.W_TREE) {
+                               if(cg.getCompressionSize(CompressionType.DDC) * 
3 < cg.getCompressionSize(CompressionType.SDC))
+                                       estimatedBestCompressionType = 
CompressionType.DDC;
+                       }
+
                        if(estimatedBestCompressionType == 
CompressionType.UNCOMPRESSED) {
                                // shortcut if uncompressed
                                return new ColGroupUncompressed(colIndexes, in, 
cs.transposed);
                        }
                        else if(estimatedBestCompressionType == 
CompressionType.SDC && colIndexes.length == 1 &&
                                in.isInSparseFormat() && cs.transposed) {
-                               // shortcut for creating SDC!
-                               // throw new NotImplementedException();
+
                                return compressSDCZero(in.getSparseBlock(), 
colIndexes, in.getNumColumns(),
                                        tmp.getDblCountMap(nrUniqueEstimate));
                        }
                        else {
-                               ABitmap ubm;
-                               if(colIndexes.length > 1)
-                                       ubm = 
BitmapEncoder.extractBitmapMultiColumns(colIndexes, in, cs.transposed,
-                                               
tmp.getDblArrayMap(nrUniqueEstimate));
-                               else
-                                       ubm = 
BitmapEncoder.extractBitmap(colIndexes, in, cs.transposed, nrUniqueEstimate);
-
-                               CompressedSizeEstimator estimator = new 
CompressedSizeEstimatorExact(in, cs);
+                               final int numRows = cs.transposed ? 
in.getNumColumns() : in.getNumRows();
 
-                               CompressedSizeInfoColGroup sizeInfo = new 
CompressedSizeInfoColGroup(
-                                       
estimator.estimateCompressedColGroupSize(ubm, colIndexes), 
cs.validCompressions, ubm);
+                               if(colIndexes.length > 1) {
+                                       final ABitmap ubm = 
BitmapEncoder.extractBitmapMultiColumns(colIndexes, in, cs.transposed,
+                                               
tmp.getDblArrayMap(nrUniqueEstimate));
+                                       CompressedSizeEstimator estimator = new 
CompressedSizeEstimatorExact(in, cs);
+                                       CompressedSizeInfoColGroup sizeInfo = 
new CompressedSizeInfoColGroup(
+                                               
estimator.estimateCompressedColGroupSize(ubm, colIndexes), 
cs.validCompressions, ubm);
+                                       return compress(colIndexes, numRows, 
ubm, estimatedBestCompressionType, cs, in,
+                                               sizeInfo.getTupleSparsity());
+                               }
+                               else {
+                                       final ABitmap ubm = 
BitmapEncoder.extractBitmap(colIndexes, in, cs.transposed, nrUniqueEstimate);
+                                       return compress(colIndexes, numRows, 
ubm, estimatedBestCompressionType, cs, in, 1.0);
+                               }
 
-                               int numRows = cs.transposed ? 
in.getNumColumns() : in.getNumRows();
-                               return compress(colIndexes, numRows, ubm, 
sizeInfo.getBestCompressionType(cs), cs, in,
-                                       sizeInfo.getTupleSparsity());
                        }
 
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index 23adae9..ffe38b1 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -280,7 +280,7 @@ public class ColGroupSDC extends ColGroupValue {
                final double[] mV = m.getDenseBlockValues();
                final double[] preAV = preAgg.getDenseBlockValues();
                final int numVals = getNumValues();
-               AIterator itStart = _indexes.getIterator(cl);
+               final AIterator itStart = _indexes.getIterator(cl);
                AIterator it = null;
                for(int rowLeft = rl, offOut = 0; rowLeft < ru; rowLeft++, 
offOut += numVals) {
                        final int offLeft = rowLeft * _numRows;
@@ -300,7 +300,7 @@ public class ColGroupSDC extends ColGroupValue {
                                preAV[def] += mV[offLeft + rc];
                        }
                }
-               if(it != null)
+               if(it != null && cu < m.getNumColumns())
                        _indexes.cacheIterator(it, cu + 1);
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
index 7307bc2..6844095 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
@@ -887,6 +887,7 @@ public abstract class ColGroupValue extends 
ColGroupCompressed implements Clonea
                if(right.length != rightColumns.length)
                        throw new DMLCompressionException(
                                "Error right not equal length " + right.length 
+ " " + rightColumns.length);
+                               
                for(int row = 0; row < leftRows.length; row++) {
                        final int outputRowOffset = leftRows[row] * outCols;
                        final double vLeft = left[row];
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
index 7f42240..ac359b3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
@@ -59,8 +59,11 @@ public abstract class AOffset implements Serializable {
        public AIterator getIterator(int row) {
                if(skipIterators != null) {
                        Map<Integer, AIterator> sk = skipIterators.get();
-                       if(sk != null && sk.containsKey(row))
-                               return sk.get(row).clone();
+                       if(sk != null && sk.containsKey(row)){
+                               AIterator it = sk.get(row);
+                               if(it != null)
+                                       return it.clone();
+                       }
                }
 
                AIterator it = getIterator();
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 04ccdc2..3ca657b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -33,6 +33,8 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
 import org.apache.sysds.runtime.compress.utils.LinearAlgebraUtils;
 import org.apache.sysds.runtime.functionobjects.Plus;
@@ -92,7 +94,7 @@ public class CLALibLeftMultBy {
                final boolean overlapping = cmb.isOverlapping();
                List<AColGroup> groups = cmb.getColGroups();
                result.allocateDenseBlock();
-               
+
                if(overlapping) {
                        LOG.warn("Inefficient TSMM with overlapping matrix 
could be implemented multi-threaded but is not yet.");
                        leftMultByCompressedTransposedMatrix(groups, groups, 
result);
@@ -110,7 +112,7 @@ public class CLALibLeftMultBy {
                                        final AColGroup g = groups.get(i);
                                        tasks.add(new 
LeftMultByCompressedTransposedMatrixTask(groups, g, result, i, groups.size()));
                                }
-                               
+
                                for(Future<Object> tret : pool.invokeAll(tasks))
                                        tret.get();
                                pool.shutdown();
@@ -228,26 +230,56 @@ public class CLALibLeftMultBy {
        }
 
        private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups, 
MatrixBlock that, MatrixBlock ret, int k,
-               int numColumns,  boolean overlapping) {
+               int numColumns, boolean overlapping) {
 
                if(that.isEmpty()) {
                        ret.setNonZeros(0);
                        return ret;
                }
 
+               boolean containsSDC = false;
+
+               for(AColGroup g : colGroups) {
+                       if(g instanceof ColGroupSDC || g instanceof 
ColGroupSDCSingle)
+                               containsSDC = true;
+               }
+
+               final List<AColGroup> filteredGroups = containsSDC ? new 
ArrayList<>() : colGroups;
+               // a constant colgroup summing the default values.
+               final double[] constV = containsSDC ? new double[numColumns] : 
null;
+
+               if(containsSDC) {
+                       for(AColGroup g : colGroups) {
+                               if(g instanceof ColGroupSDC)
+                                       filteredGroups.add(((ColGroupSDC) 
g).extractCommon(constV));
+                               else if(g instanceof ColGroupSDCSingle)
+                                       filteredGroups.add(((ColGroupSDCSingle) 
g).extractCommon(constV));
+                               else
+                                       filteredGroups.add(g);
+                       }
+               }
+
                ret.allocateDenseBlock();
 
-               if(k == 1)
-                       leftMultByMatrixPrimitive(colGroups, that, ret, 
numColumns, 0, that.getNumRows());
+               if(k == 1) {
+                       leftMultByMatrixPrimitive(filteredGroups, that, ret, 
numColumns, 0, that.getNumRows());
+                       if(containsSDC) {
+                               MatrixBlock rowSum = that.rowSum();
+                               if(rowSum.isInSparseFormat())
+                                       rowSum.sparseToDense();
+                               double[] rowSums = rowSum.getDenseBlockValues();
+                               outerProduct(rowSums, constV, 
ret.getDenseBlockValues());
+                       }
+               }
                else {
                        try {
                                final ExecutorService pool = 
CommonThreadPool.get(k);
                                final ArrayList<Callable<MatrixBlock>> tasks = 
new ArrayList<>();
                                final int rowBlockSize = that.getNumRows() < 8 
? 1 : Math.min(Math.max(that.getNumRows() / k, 1), 8);
-                               // final int rowBlockSize = 4;
+                               double[] rowSums = null;
 
                                if(overlapping) {
-                                       for(AColGroup g : colGroups) {
+                                       for(AColGroup g : filteredGroups) {
                                                MatrixBlock tmpRet = new 
MatrixBlock(ret.getNumRows(), ret.getNumColumns(), false);
                                                tmpRet.allocateDenseBlock();
                                                for(int blo = 0; blo < 
that.getNumRows(); blo += rowBlockSize)
@@ -256,6 +288,12 @@ public class CLALibLeftMultBy {
 
                                        }
                                        List<Future<MatrixBlock>> futures = 
pool.invokeAll(tasks);
+                                       if(containsSDC) {
+                                               MatrixBlock rowSum = 
that.rowSum();
+                                               if(rowSum.isInSparseFormat())
+                                                       rowSum.sparseToDense();
+                                               rowSums = 
rowSum.getDenseBlockValues();
+                                       }
                                        pool.shutdown();
                                        BinaryOperator op = new 
BinaryOperator(Plus.getPlusFnObject());
                                        for(Future<MatrixBlock> future : 
futures)
@@ -264,36 +302,40 @@ public class CLALibLeftMultBy {
                                else {
                                        if(rowBlockSize > 2) {
                                                for(int blo = 0; blo < 
that.getNumRows(); blo += rowBlockSize) {
-                                                       tasks.add(new 
LeftMatrixColGroupMultTaskNew(colGroups, that, ret, numColumns, blo,
+                                                       tasks.add(new 
LeftMatrixColGroupMultTaskNew(filteredGroups, that, ret, numColumns, blo,
                                                                Math.min(blo + 
rowBlockSize, that.getNumRows())));
                                                }
                                        }
                                        else {
-
-                                               List<List<AColGroup>> split = 
split(colGroups, Math.max(k / that.getNumRows(), 1));
+                                               List<List<AColGroup>> split = 
split(filteredGroups, Math.max(k / that.getNumRows(), 1));
                                                for(int blo = 0; blo < 
that.getNumRows(); blo += rowBlockSize) {
                                                        for(List<AColGroup> gr 
: split)
                                                                tasks.add(new 
LeftMatrixColGroupMultTaskNew(gr, that, ret, numColumns, blo,
                                                                        
Math.min(blo + rowBlockSize, that.getNumRows())));
                                                }
-
-                                               // for(AColGroup g : colGroups)
-                                               // for(int blo = 0; blo < 
that.getNumRows(); blo += rowBlockSize)
-                                               // tasks.add(new 
LeftMatrixColGroupMultTaskOld(g, that, ret, blo,
-                                               // Math.min(blo + rowBlockSize, 
that.getNumRows()), maxNumValues));
                                        }
 
                                        List<Future<MatrixBlock>> futures = 
pool.invokeAll(tasks);
+                                       if(containsSDC) {
+                                               MatrixBlock rowSum = 
that.rowSum();
+                                               if(rowSum.isInSparseFormat())
+                                                       rowSum.sparseToDense();
+                                               rowSums = 
rowSum.getDenseBlockValues();
+                                       }
                                        pool.shutdown();
                                        for(Future<MatrixBlock> future : 
futures)
                                                future.get();
                                }
 
+                               if(containsSDC)
+                                       outerProduct(rowSums, constV, 
ret.getDenseBlockValues());
+
                        }
                        catch(InterruptedException | ExecutionException e) {
                                throw new DMLRuntimeException(e);
                        }
                }
+
                ret.recomputeNonZeros();
                return ret;
        }
@@ -311,6 +353,16 @@ public class CLALibLeftMultBy {
                return ret;
        }
 
+       /**
+        * Add the outer product of leftRowSum and rightColumnSum into result, i.e.
+        * result[row * rightColumnSum.length + col] += leftRowSum[row] * rightColumnSum[col].
+        * 
+        * In leftMultByMatrix this is called with the row sums of the left matrix and constV (the summed default
+        * values extracted from the SDC column groups). This adds back the contribution of the common elements
+        * that those groups skip.
+        * 
+        * @param leftRowSum     One value per output row
+        * @param rightColumnSum One value per output column
+        * @param result         Dense row-major output, accumulated into in place. It must hold at least
+        *                       leftRowSum.length * rightColumnSum.length entries; this is not checked here.
+        */
+       private static void outerProduct(final double[] leftRowSum, final 
double[] rightColumnSum, final double[] result) {
+               for(int row = 0; row < leftRowSum.length; row++) {
+                       // Row-major offset of this output row.
+                       final int offOut = rightColumnSum.length * row;
+                       final double vLeft = leftRowSum[row];
+                       for(int col = 0; col < rightColumnSum.length; col++) {
+                               result[offOut + col] += vLeft * 
rightColumnSum[col];
+                       }
+               }
+       }
+
        private static class LeftMatrixColGroupMultTaskOld implements 
Callable<MatrixBlock> {
                private final AColGroup _group;
                private final MatrixBlock _that;
@@ -379,47 +431,44 @@ public class CLALibLeftMultBy {
                        }
                }
                else {
-                       List<ColGroupValue> v = new ArrayList<>();
-                       int rowBlockSize = 1;
-                       List<MatrixBlock> preAgg = new ArrayList<>();
-                       int colGroupBlocking = 16;
-                       for(int j = 0; j < colGroupBlocking; j++) {
-                               MatrixBlock m = new MatrixBlock(1, 1, false);
-                               m.allocateDenseBlock();
-                               preAgg.add(m);
-                       }
+                       // The number of rows to process together
+                       final int rowBlockSize = 1;
+                       // The number of column groups to process together
+                       final int colGroupBlocking = 16;
+
+                       // Allocate pre Aggregate Array List
+                       final List<MatrixBlock> preAgg = 
populatePreAggregate(colGroupBlocking);
+                       // Allocate a ColGroupValue array for the Column Groups 
of Value Type.
+                       final List<ColGroupValue> ColGroupValues = 
preFilterAndMultiply(colGroups, that, ret, numColumns, rl, ru);
 
+                       // Allocate temporary Result matrix.
                        MatrixBlock tmpRes = new MatrixBlock(rowBlockSize, 
numColumns, false);
 
-                       for(int j = 0; j < colGroups.size(); j++) {
-                               AColGroup a = colGroups.get(j);
-                               if(a instanceof ColGroupValue) {
-                                       ColGroupValue av = (ColGroupValue) a;
-                                       v.add(av);
-                               }
-                               else
-                                       a.leftMultByMatrix(that, ret, rl, ru);
-                       }
-                       Collections.sort(v, 
Comparator.comparing(AColGroup::getNumValues).reversed());
-                       // LOG.error(v);
-                       for(int g = 0; g < v.size(); g += colGroupBlocking) {
+                       for(int g = 0; g < ColGroupValues.size(); g += 
colGroupBlocking) {
                                final int gEnd = Math.min(g + colGroupBlocking, 
colGroups.size());
-                               for(int j = g; j < gEnd && j < v.size(); j++) {
-                                       ColGroupValue cg = v.get(j);
-                                       preAgg.get(j % 
colGroupBlocking).reset(rowBlockSize, cg.getNumValues(), false);
+
+                               // for each column group in the current block 
allocate the preaggregate array.
+                               for(int j = g; j < gEnd && j < 
ColGroupValues.size(); j++) {
+                                       ColGroupValue cg = 
ColGroupValues.get(j);
+                                       int nVals = cg.getNumValues();
+                                       preAgg.get(j % 
colGroupBlocking).reset(rowBlockSize, nVals, false);
                                }
+
                                // int colBlockSize = 16000;
                                int colBlockSize = 64000;
 
+                               // For each row block
                                for(int h = rl; h < ru; h += rowBlockSize) {
+                                       // For each column block
                                        for(int i = 0; i < 
that.getNumColumns(); i += colBlockSize) {
-                                               for(int j = g; j < gEnd && j < 
v.size(); j++) {
-                                                       
v.get(j).preAggregateDense(that, preAgg.get(j % colGroupBlocking), h,
+                                               // Pre Aggregate each column 
group in block
+                                               for(int j = g; j < gEnd && j < 
ColGroupValues.size(); j++) {
+                                                       
ColGroupValues.get(j).preAggregateDense(that, preAgg.get(j % colGroupBlocking), 
h,
                                                                Math.min(h + 
rowBlockSize, ru), i, Math.min(i + colBlockSize, that.getNumColumns()));
                                                }
                                        }
-                                       for(int j = g; j < gEnd && j < 
v.size(); j++) {
-                                               ColGroupValue vj = v.get(j);
+                                       for(int j = g; j < gEnd && j < 
ColGroupValues.size(); j++) {
+                                               ColGroupValue vj = 
ColGroupValues.get(j);
                                                MatrixBlock preAggJ = 
preAgg.get(j % colGroupBlocking);
                                                preAggJ.recomputeNonZeros();
                                                tmpRes.reset(rowBlockSize, 
vj.getNumCols(), false);
@@ -431,4 +480,32 @@ public class CLALibLeftMultBy {
                        }
                }
        }
+
+       /**
+        * Allocate the reusable pre-aggregate buffers, one slot per column group processed together in a block.
+        * 
+        * Each buffer is created as a 1x1 dense MatrixBlock. Callers reset it to the needed size (rowBlockSize x
+        * number of distinct values of the group) before use.
+        * 
+        * @param colGroupBlocking The number of buffers to allocate
+        * @return A list of colGroupBlocking dense-allocated MatrixBlocks
+        */
+       private static List<MatrixBlock> populatePreAggregate(int 
colGroupBlocking) {
+               final List<MatrixBlock> preAgg = new ArrayList<>();
+               // populate the preAgg array.
+               for(int j = 0; j < colGroupBlocking; j++) {
+
+                       MatrixBlock m = new MatrixBlock(1, 1, false);
+                       m.allocateDenseBlock();
+                       preAgg.add(m);
+               }
+               return preAgg;
+       }
+
+       /**
+        * Split the column groups into ColGroupValue groups and all other groups.
+        * 
+        * Groups that are not ColGroupValue are multiplied immediately via leftMultByMatrix for rows rl to ru of
+        * that, writing into ret. ColGroupValue groups are collected and returned sorted by number of distinct
+        * values, largest first, so the caller can pre-aggregate them in blocks.
+        * 
+        * @param colGroups  The column groups to split
+        * @param that       The uncompressed left-hand matrix
+        * @param ret        The output matrix
+        * @param numColumns NOTE(review): not read in this method
+        * @param rl         Row lower bound of that to process (inclusive)
+        * @param ru         Row upper bound of that to process (exclusive)
+        * @return The ColGroupValue groups, sorted by descending number of values
+        */
+       private static List<ColGroupValue> preFilterAndMultiply(List<AColGroup> 
colGroups, MatrixBlock that,
+               MatrixBlock ret, int numColumns, int rl, int ru) {
+               final List<ColGroupValue> ColGroupValues = new ArrayList<>();
+               for(int j = 0; j < colGroups.size(); j++) {
+                       AColGroup a = colGroups.get(j);
+                       // Value groups are deferred to the blocked pre-aggregation; others are multiplied now.
+                       if(a instanceof ColGroupValue) {
+                               ColGroupValue av = (ColGroupValue) a;
+                               ColGroupValues.add(av);
+                       }
+                       else
+                               a.leftMultByMatrix(that, ret, rl, ru);
+               }
+               // Largest number of distinct values first.
+               Collections.sort(ColGroupValues, 
Comparator.comparing(AColGroup::getNumValues).reversed());
+               return ColGroupValues;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 8d2d672..160f8e9 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -975,13 +975,23 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
        }
 
        /**
-        * Wrapper method for reduceall-colSum of a matrix.
+        * Wrapper method for a single threaded column sum (uack+) of a matrix.
         * 
         * @return A new MatrixBlock containing the column sums of this matrix.
         */
        public MatrixBlock colSum() {
                AggregateUnaryOperator op = 
InstructionUtils.parseBasicAggregateUnaryOperator("uack+", 1);
-               return aggregateUnaryOperations(op, null, 1000, null);
+               return aggregateUnaryOperations(op, null, 1000, null, true);
+       }
+
+       /**
+        * Wrapper method for a single threaded row sum (uark+) of a matrix, producing one sum per row.
+        * 
+        * @return A new MatrixBlock containing the row sums of this matrix.
+        */
+       public MatrixBlock rowSum(){
+               AggregateUnaryOperator op = 
InstructionUtils.parseBasicAggregateUnaryOperator("uark+", 1);
+               return aggregateUnaryOperations(op, null, 1000, null, true);
        }
 
        /**

Reply via email to