This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 8ecd1fb  [SYSTEMDS-3105] CLA Left MM Shard Common element sum upgrade
8ecd1fb is described below

commit 8ecd1fb5a14d6b82fc41d240ce2f477ef4d859e9
Author: baunsgaard <[email protected]>
AuthorDate: Fri Aug 27 22:02:29 2021 +0200

    [SYSTEMDS-3105] CLA Left MM Shard Common element sum upgrade
    
    This commit expands on the shared element sum from yesterday, improving
    the performance gains further. This, together with the workload-aware
    improvements today, makes LMM with 16 rows on the left
    go 19-21x faster than the default sparse matrix multiplication on
    census_enc, making it take 10 ms per multiplication vs 200 ms with
    our default.
---
 .../runtime/compress/lib/CLALibLeftMultBy.java     | 222 ++++++++++++---------
 1 file changed, 126 insertions(+), 96 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 3ca657b..10cd2e8 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -37,6 +37,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
 import org.apache.sysds.runtime.compress.utils.LinearAlgebraUtils;
+import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -73,7 +74,7 @@ public class CLALibLeftMultBy {
                if(m2.isEmpty())
                        return ret;
 
-               ret = leftMultByMatrix(m1.getColGroups(), m2, ret, k, 
m1.getNumColumns(), m1.isOverlapping());
+               ret = leftMultByMatrix(m1.getColGroups(), m2, ret, k, 
m1.isOverlapping());
                ret.recomputeNonZeros();
                return ret;
        }
@@ -230,13 +231,13 @@ public class CLALibLeftMultBy {
        }
 
        private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups, 
MatrixBlock that, MatrixBlock ret, int k,
-               int numColumns, boolean overlapping) {
+               boolean overlapping) {
 
                if(that.isEmpty()) {
                        ret.setNonZeros(0);
                        return ret;
                }
-
+               final int numColumnsOut = ret.getNumColumns();
                boolean containsSDC = false;
 
                for(AColGroup g : colGroups) {
@@ -246,7 +247,7 @@ public class CLALibLeftMultBy {
 
                final List<AColGroup> filteredGroups = containsSDC ? new 
ArrayList<>() : colGroups;
                // a constant colgroup summing the default values.
-               final double[] constV = containsSDC ? new double[numColumns] : 
null;
+               final double[] constV = containsSDC ? new double[numColumnsOut] 
: null;
 
                if(containsSDC) {
                        for(AColGroup g : colGroups) {
@@ -260,23 +261,17 @@ public class CLALibLeftMultBy {
                }
 
                ret.allocateDenseBlock();
+               final double[] rowSums = containsSDC ? new 
double[that.getNumRows()] : null;
 
                if(k == 1) {
-                       leftMultByMatrixPrimitive(filteredGroups, that, ret, 
numColumns, 0, that.getNumRows());
-                       if(containsSDC) {
-                               MatrixBlock rowSum = that.rowSum();
-                               if(rowSum.isInSparseFormat())
-                                       rowSum.sparseToDense();
-                               double[] rowSums = rowSum.getDenseBlockValues();
-                               outerProduct(rowSums, constV, 
ret.getDenseBlockValues());
-                       }
+                       leftMultByMatrixPrimitive(filteredGroups, that, ret, 0, 
that.getNumRows(), rowSums);
                }
                else {
                        try {
                                final ExecutorService pool = 
CommonThreadPool.get(k);
                                final ArrayList<Callable<MatrixBlock>> tasks = 
new ArrayList<>();
-                               final int rowBlockSize = that.getNumRows() < 8 
? 1 : Math.min(Math.max(that.getNumRows() / k, 1), 8);
-                               double[] rowSums = null;
+                               final int rowBlockSize = that.getNumRows() <= k 
? 1 : Math.min(Math.max(that.getNumRows() / k * 2, 1),
+                                       8);
 
                                if(overlapping) {
                                        for(AColGroup g : filteredGroups) {
@@ -288,54 +283,56 @@ public class CLALibLeftMultBy {
 
                                        }
                                        List<Future<MatrixBlock>> futures = 
pool.invokeAll(tasks);
-                                       if(containsSDC) {
-                                               MatrixBlock rowSum = 
that.rowSum();
-                                               if(rowSum.isInSparseFormat())
-                                                       rowSum.sparseToDense();
-                                               rowSums = 
rowSum.getDenseBlockValues();
-                                       }
                                        pool.shutdown();
                                        BinaryOperator op = new 
BinaryOperator(Plus.getPlusFnObject());
                                        for(Future<MatrixBlock> future : 
futures)
                                                ret.binaryOperationsInPlace(op, 
future.get());
                                }
                                else {
-                                       if(rowBlockSize > 2) {
+                                       final int numberSplits = Math.max((k / 
(ret.getNumRows() / rowBlockSize)), 1);
+                                       // LOG.error("RowBLockSize:" 
+rowBlockSize + " Splits " + numberSplits);
+                                       if(numberSplits == 1) {
                                                for(int blo = 0; blo < 
that.getNumRows(); blo += rowBlockSize) {
-                                                       tasks.add(new 
LeftMatrixColGroupMultTaskNew(filteredGroups, that, ret, numColumns, blo,
-                                                               Math.min(blo + 
rowBlockSize, that.getNumRows())));
+                                                       tasks.add(new 
LeftMatrixColGroupMultTaskNew(filteredGroups, that, ret, blo,
+                                                               Math.min(blo + 
rowBlockSize, that.getNumRows()), rowSums));
                                                }
                                        }
                                        else {
-                                               List<List<AColGroup>> split = 
split(filteredGroups, Math.max(k / that.getNumRows(), 1));
+                                               List<List<AColGroup>> split = 
split(filteredGroups, numberSplits);
                                                for(int blo = 0; blo < 
that.getNumRows(); blo += rowBlockSize) {
-                                                       for(List<AColGroup> gr 
: split)
-                                                               tasks.add(new 
LeftMatrixColGroupMultTaskNew(gr, that, ret, numColumns, blo,
-                                                                       
Math.min(blo + rowBlockSize, that.getNumRows())));
+                                                       for(int i = 0; i < 
split.size(); i++) {
+                                                               List<AColGroup> 
gr = split.get(i);
+                                                               if(i == 0) {
+                                                                       // the 
first thread also have the responsibility to calculate the som of the left
+                                                                       // hand 
side.
+                                                                       
tasks.add(new LeftMatrixColGroupMultTaskNew(gr, that, ret, blo,
+                                                                               
Math.min(blo + rowBlockSize, that.getNumRows()), rowSums));
+                                                               }
+                                                               else {
+                                                                       
tasks.add(new LeftMatrixColGroupMultTaskNew(gr, that, ret, blo,
+                                                                               
Math.min(blo + rowBlockSize, that.getNumRows()), null));
+                                                               }
+                                                       }
                                                }
                                        }
 
                                        List<Future<MatrixBlock>> futures = 
pool.invokeAll(tasks);
-                                       if(containsSDC) {
-                                               MatrixBlock rowSum = 
that.rowSum();
-                                               if(rowSum.isInSparseFormat())
-                                                       rowSum.sparseToDense();
-                                               rowSums = 
rowSum.getDenseBlockValues();
-                                       }
+
                                        pool.shutdown();
                                        for(Future<MatrixBlock> future : 
futures)
                                                future.get();
                                }
 
-                               if(containsSDC)
-                                       outerProduct(rowSums, constV, 
ret.getDenseBlockValues());
-
                        }
                        catch(InterruptedException | ExecutionException e) {
                                throw new DMLRuntimeException(e);
                        }
                }
 
+               // add the correction layer for the subtracted common values.
+               if(rowSums != null)
+                       outerProduct(rowSums, constV, 
ret.getDenseBlockValues());
+
                ret.recomputeNonZeros();
                return ret;
        }
@@ -396,112 +393,145 @@ public class CLALibLeftMultBy {
                private final MatrixBlock _ret;
                private final int _rl;
                private final int _ru;
-               private final int _numColumns;
+               private final double[] _rowSums;
 
-               protected LeftMatrixColGroupMultTaskNew(List<AColGroup> groups, 
MatrixBlock that, MatrixBlock ret,
-                       int numColumns, int rl, int ru) {
+               protected LeftMatrixColGroupMultTaskNew(List<AColGroup> groups, 
MatrixBlock that, MatrixBlock ret, int rl,
+                       int ru, double[] rowSums) {
                        _groups = groups;
                        _that = that;
                        _ret = ret;
                        _rl = rl;
                        _ru = ru;
-                       _numColumns = numColumns;
+                       _rowSums = rowSums;
                }
 
                @Override
                public MatrixBlock call() {
                        try {
-                               leftMultByMatrixPrimitive(_groups, _that, _ret, 
_numColumns, _rl, _ru);
+                               leftMultByMatrixPrimitive(_groups, _that, _ret, 
_rl, _ru, _rowSums);
                        }
                        catch(Exception e) {
+                               e.printStackTrace();
                                throw new DMLRuntimeException(e);
                        }
                        return _ret;
                }
        }
 
-       private static void leftMultByMatrixPrimitive(List<AColGroup> 
colGroups, MatrixBlock that, MatrixBlock ret,
-               int numColumns, int rl, int ru) {
+       private static void leftMultByMatrixPrimitive(List<AColGroup> 
colGroups, MatrixBlock that, MatrixBlock ret, int rl,
+               int ru, double[] rowSums) {
+               if(that.isInSparseFormat())
+                       leftMultByMatrixPrimitiveSparse(colGroups, that, ret, 
rl, ru, rowSums);
+               else
+                       leftMultByMatrixPrimitiveDense(colGroups, that, ret, 
rl, ru, rowSums);
+       }
+
+       private static void leftMultByMatrixPrimitiveSparse(List<AColGroup> 
colGroups, MatrixBlock that, MatrixBlock ret,
+               int rl, int ru, double[] rowSum) {
 
-               if(that.isInSparseFormat()) {
-                       for(int i = rl; i < ru; i++) {
-                               for(int j = 0; j < colGroups.size(); j++) {
-                                       colGroups.get(j).leftMultByMatrix(that, 
ret, i, i + 1);
+               for(int i = rl; i < ru; i++) {
+                       for(int j = 0; j < colGroups.size(); j++) {
+                               colGroups.get(j).leftMultByMatrix(that, ret, i, 
i + 1);
+                       }
+                       if(rowSum != null) {
+                               final SparseBlock sb = that.getSparseBlock();
+                               if(!sb.isEmpty(i)){
+                                       final int apos = sb.pos(i);
+                                       final int alen = sb.size(i) + apos;
+                                       final double[] aval = sb.values(i);
+                                       for(int j = apos; j < alen; j++)
+                                               rowSum[i] += aval[j];
                                }
                        }
                }
-               else {
-                       // The number of rows to process together
-                       final int rowBlockSize = 1;
-                       // The number of column groups to process together
-                       final int colGroupBlocking = 16;
+       }
 
-                       // Allocate pre Aggregate Array List
-                       final List<MatrixBlock> preAgg = 
populatePreAggregate(colGroupBlocking);
-                       // Allocate a ColGroupValue array for the Column Groups 
of Value Type.
-                       final List<ColGroupValue> ColGroupValues = 
preFilterAndMultiply(colGroups, that, ret, numColumns, rl, ru);
+       private static void leftMultByMatrixPrimitiveDense(List<AColGroup> 
colGroups, MatrixBlock that, MatrixBlock ret,
+               int rl, int ru, double[] rowSum) {
 
-                       // Allocate temporary Result matrix.
-                       MatrixBlock tmpRes = new MatrixBlock(rowBlockSize, 
numColumns, false);
+               final int numColsOut = ret.getNumColumns();
+               // Allocate a ColGroupValue array for the Column Groups of 
Value Type and multiply out any other columns.
+               final List<ColGroupValue> ColGroupValues = 
preFilterAndMultiply(colGroups, that, ret, rl, ru);
 
-                       for(int g = 0; g < ColGroupValues.size(); g += 
colGroupBlocking) {
-                               final int gEnd = Math.min(g + colGroupBlocking, 
colGroups.size());
+               // The number of rows to process together
+               final int rowBlockSize = 1;
+               // The number of column groups to process together
+               // the value should ideally be set so that the colgroups fits 
into cache together with a row block.
+               // currently we only try to avoid having a dangling small 
number of column groups in the last block.
+               final int colGroupBlocking = ColGroupValues.size() % 16 < 4 ? 
20 : 16;
 
-                               // for each column group in the current block 
allocate the preaggregate array.
-                               for(int j = g; j < gEnd && j < 
ColGroupValues.size(); j++) {
-                                       ColGroupValue cg = 
ColGroupValues.get(j);
-                                       int nVals = cg.getNumValues();
-                                       preAgg.get(j % 
colGroupBlocking).reset(rowBlockSize, nVals, false);
-                               }
+               // Allocate pre Aggregate Array List
+               final MatrixBlock[] preAgg = 
populatePreAggregate(colGroupBlocking);
 
-                               // int colBlockSize = 16000;
-                               int colBlockSize = 64000;
-
-                               // For each row block
-                               for(int h = rl; h < ru; h += rowBlockSize) {
-                                       // For each column block
-                                       for(int i = 0; i < 
that.getNumColumns(); i += colBlockSize) {
-                                               // Pre Aggregate each column 
group in block
-                                               for(int j = g; j < gEnd && j < 
ColGroupValues.size(); j++) {
-                                                       
ColGroupValues.get(j).preAggregateDense(that, preAgg.get(j % colGroupBlocking), 
h,
-                                                               Math.min(h + 
rowBlockSize, ru), i, Math.min(i + colBlockSize, that.getNumColumns()));
-                                               }
-                                       }
+               // Allocate temporary Result matrix.
+               MatrixBlock tmpRes = new MatrixBlock(rowBlockSize, numColsOut, 
false);
+
+               // For each column group block
+               for(int g = 0; g < ColGroupValues.size(); g += 
colGroupBlocking) {
+                       final int gEnd = Math.min(g + colGroupBlocking, 
ColGroupValues.size());
+
+                       // For each column group in the current block allocate 
the preaggregate array.
+                       for(int j = g; j < gEnd && j < ColGroupValues.size(); 
j++) {
+                               ColGroupValue cg = ColGroupValues.get(j);
+                               int nVals = cg.getNumValues();
+                               preAgg[j % 
colGroupBlocking].reset(rowBlockSize, nVals, false);
+                       }
+
+                       int colBlockSize = 32000;
+
+                       // For each row block
+                       for(int h = rl; h < ru; h += rowBlockSize) {
+                               // For each column block
+                               final int rowUpper = Math.min(h + rowBlockSize, 
ru);
+                               for(int i = 0; i < that.getNumColumns(); i += 
colBlockSize) {
+                                       final int colUpper = Math.min(i + 
colBlockSize, that.getNumColumns());
+                                       // Pre Aggregate each column group in 
block
                                        for(int j = g; j < gEnd && j < 
ColGroupValues.size(); j++) {
-                                               ColGroupValue vj = 
ColGroupValues.get(j);
-                                               MatrixBlock preAggJ = 
preAgg.get(j % colGroupBlocking);
-                                               preAggJ.recomputeNonZeros();
-                                               tmpRes.reset(rowBlockSize, 
vj.getNumCols(), false);
-                                               MatrixBlock tmp = 
vj.leftMultByPreAggregateMatrix(preAggJ, tmpRes);
-                                               vj.addMatrixToResult(tmp, ret, 
h, Math.min(h + rowBlockSize, ru));
-                                               preAggJ.reset();
+                                               
ColGroupValues.get(j).preAggregateDense(that, preAgg[j % colGroupBlocking], h, 
rowUpper, i,
+                                                       colUpper);
                                        }
+                                       if(rowSum != null) {
+                                               final double[] thatV = 
that.getDenseBlockValues();
+                                               for(int r = h; r < rowUpper; 
r++) {
+                                                       final int rowOff = r * 
that.getNumColumns();
+                                                       for(int c = rowOff + i; 
c < rowOff + colUpper; c++)
+                                                               rowSum[r] += 
thatV[c];
+                                               }
+                                       }
+                               }
+                               // Multiply out the preAggregate to the output 
matrix.
+                               for(int j = g; j < gEnd && j < 
ColGroupValues.size(); j++) {
+                                       ColGroupValue vj = 
ColGroupValues.get(j);
+                                       MatrixBlock preAggJ = preAgg[j % 
colGroupBlocking];
+                                       preAggJ.recomputeNonZeros();
+                                       tmpRes.reset(rowBlockSize, 
vj.getNumCols(), false);
+                                       MatrixBlock tmp = 
vj.leftMultByPreAggregateMatrix(preAggJ, tmpRes);
+                                       vj.addMatrixToResult(tmp, ret, h, 
Math.min(h + rowBlockSize, ru));
+                                       preAggJ.reset();
                                }
                        }
                }
+
        }
 
-       private static List<MatrixBlock> populatePreAggregate(int 
colGroupBlocking) {
-               final List<MatrixBlock> preAgg = new ArrayList<>();
+       private static MatrixBlock[] populatePreAggregate(int colGroupBlocking) 
{
+               final MatrixBlock[] preAgg = new MatrixBlock[colGroupBlocking];
                // poplate the preAgg array.
                for(int j = 0; j < colGroupBlocking; j++) {
-
                        MatrixBlock m = new MatrixBlock(1, 1, false);
                        m.allocateDenseBlock();
-                       preAgg.add(m);
+                       preAgg[j] = m;
                }
                return preAgg;
        }
 
        private static List<ColGroupValue> preFilterAndMultiply(List<AColGroup> 
colGroups, MatrixBlock that,
-               MatrixBlock ret, int numColumns, int rl, int ru) {
-               final List<ColGroupValue> ColGroupValues = new ArrayList<>();
+               MatrixBlock ret, int rl, int ru) {
+               final List<ColGroupValue> ColGroupValues = new 
ArrayList<>(colGroups.size());
                for(int j = 0; j < colGroups.size(); j++) {
                        AColGroup a = colGroups.get(j);
-                       if(a instanceof ColGroupValue) {
-                               ColGroupValue av = (ColGroupValue) a;
-                               ColGroupValues.add(av);
-                       }
+                       if(a instanceof ColGroupValue)
+                               ColGroupValues.add((ColGroupValue) a);
                        else
                                a.leftMultByMatrix(that, ret, rl, ru);
                }

Reply via email to