This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new fb60577586 [MINOR] Parallel Compressed LMM
fb60577586 is described below

commit fb605775865d2ec0fbcc3aff81975576f8baa5e1
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Mon Oct 30 15:05:17 2023 +0100

    [MINOR] Parallel Compressed LMM
---
 .../runtime/compress/lib/CLALibLeftMultBy.java     | 96 ++++++++++++++++++++--
 .../sysds/runtime/compress/lib/CLALibMMChain.java  | 42 ++++++++++
 .../runtime/compress/lib/CLALibRightMultBy.java    |  4 +-
 3 files changed, 133 insertions(+), 9 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 6029a87d46..30c1109d3a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -32,11 +32,14 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup;
 import org.apache.sysds.runtime.compress.colgroup.APreAgg;
 import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
 import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -45,7 +48,7 @@ import org.apache.sysds.runtime.util.CommonThreadPool;
 public final class CLALibLeftMultBy {
        private static final Log LOG = 
LogFactory.getLog(CLALibLeftMultBy.class.getName());
 
-       private CLALibLeftMultBy(){
+       private CLALibLeftMultBy() {
                // private constructor
        }
 
@@ -139,7 +142,15 @@ public final class CLALibLeftMultBy {
        }
 
        private static MatrixBlock 
leftMultByCompressedTransposedMatrix(CompressedMatrixBlock right,
-               CompressedMatrixBlock left, MatrixBlock ret, int k) {
+               CompressedMatrixBlock left, final MatrixBlock ret, int k) {
+               if(k > 1 && ret.getInMemorySize() < 1000000)
+                       return 
leftMultByCompressedTransposedMatrixParallel(right, left, ret, k);
+               else
+                       return 
leftMultByCompressedTransposedMatrixSingleThread(right, left, ret);
+       }
+
+       private static MatrixBlock 
leftMultByCompressedTransposedMatrixParallel(CompressedMatrixBlock right,
+               CompressedMatrixBlock left, final MatrixBlock ret, int k) {
 
                final int sd = right.getNumRows(); // shared dim
                final int cr = right.getNumColumns();
@@ -149,18 +160,88 @@ public final class CLALibLeftMultBy {
                final List<AColGroup> leftCG = left.getColGroups();
 
                final boolean containsRight = 
CLALibUtils.shouldPreFilter(rightCG);
-               double[] cR = containsRight ? new double[cr] : null;
+               final double[] cR = containsRight ? new double[cr] : null;
                final List<AColGroup> fRight = 
CLALibUtils.filterGroups(rightCG, cR);
 
                final boolean containsLeft = 
CLALibUtils.shouldPreFilter(leftCG);
-               double[] cL = containsLeft ? new double[rl] : null;
+               final double[] cL = containsLeft ? new double[rl] : null;
                final List<AColGroup> fLeft = CLALibUtils.filterGroups(leftCG, 
cL);
 
+               // Force dense output
+               ret.setNonZeros((long) ret.getNumRows() * ret.getNumColumns());
+               ret.allocateDenseBlock();
+
+               final ExecutorService ex = CommonThreadPool.get(k);
+               final List<Future<MatrixBlock>> t = new ArrayList<>();
+
+               for(int j = 0; j < fLeft.size(); j++) {
+                       final int jj = j;
+                       t.add(ex.submit(() -> {
+                               MatrixBlock retT = new 
MatrixBlock(ret.getNumRows(), ret.getNumColumns(), false);
+                               retT.allocateDenseBlock();
+                               for(int i = 0; i < fRight.size(); i++) {
+                                       
fRight.get(i).leftMultByAColGroup(fLeft.get(jj), retT, sd);
+                               }
+                               retT.examSparsity(true);
+                               return retT;
+                       }));
+               }
+
+               try {
+                       final double[] retV = ret.getDenseBlockValues();
+                       if(containsLeft && containsRight)
+                               // if both -- multiply the left and right 
vectors scaling by number of shared dim
+                               outerProductWithScaling(cL, cR, sd, retV);
+                       if(containsLeft) // if left -- multiply left with right 
sum
+                               outerProduct(cL, CLALibUtils.getColSum(fRight, 
cr, sd), retV);
+                       if(containsRight)// if right -- multiply right with 
left sum
+                               outerProduct(CLALibUtils.getColSum(fLeft, rl, 
sd), cR, retV);
+                       for(Future<MatrixBlock> f : t) {
+                               MatrixBlock mb = f.get();
+                               if(!mb.isEmpty()) {
+                                       if(mb.isInSparseFormat())
+                                               
LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new 
BinaryOperator(Plus.getPlusFnObject()));
+                                       else 
if(mb.getDenseBlock().isContiguous())
+                                               
LibMatrixMult.vectAdd(mb.getDenseBlockValues(), retV, 0, 0, retV.length);
+                                       else
+                                               
LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new 
BinaryOperator(Plus.getPlusFnObject()));
+                               }
+                       }
+                       ret.recomputeNonZeros(k);
+               }
+               catch(Exception e) {
+                       throw new DMLCompressionException("Failed parallel Left 
Compressed Mult", e);
+               }
+               finally {
+                       ex.shutdown();
+               }
+               return ret;
+       }
+
+       private static MatrixBlock 
leftMultByCompressedTransposedMatrixSingleThread(CompressedMatrixBlock right,
+               CompressedMatrixBlock left, final MatrixBlock ret) {
+               final int sd = right.getNumRows(); // shared dim
+               final int cr = right.getNumColumns();
+               final int rl = left.getNumColumns();
+
+               final List<AColGroup> rightCG = right.getColGroups();
+               final List<AColGroup> leftCG = left.getColGroups();
+
+               final boolean containsRight = 
CLALibUtils.shouldPreFilter(rightCG);
+               final double[] cR = containsRight ? new double[cr] : null;
+               final List<AColGroup> fRight = 
CLALibUtils.filterGroups(rightCG, cR);
+
+               final boolean containsLeft = 
CLALibUtils.shouldPreFilter(leftCG);
+               final double[] cL = containsLeft ? new double[rl] : null;
+               final List<AColGroup> fLeft = CLALibUtils.filterGroups(leftCG, 
cL);
+
+               // Force dense output
+               ret.setNonZeros((long) ret.getNumRows() * ret.getNumColumns());
+               ret.allocateDenseBlock();
                for(int j = 0; j < fLeft.size(); j++)
                        for(int i = 0; i < fRight.size(); i++)
                                fRight.get(i).leftMultByAColGroup(fLeft.get(j), 
ret, sd);
-
-               double[] retV = ret.getDenseBlockValues();
+               final double[] retV = ret.getDenseBlockValues();
                if(containsLeft && containsRight)
                        // if both -- multiply the left and right vectors 
scaling by number of shared dim
                        outerProductWithScaling(cL, cR, sd, retV);
@@ -169,7 +250,6 @@ public final class CLALibLeftMultBy {
                if(containsRight)// if right -- multiply right with left sum
                        outerProduct(CLALibUtils.getColSum(fLeft, rl, sd), cR, 
retV);
                ret.recomputeNonZeros();
-
                return ret;
        }
 
@@ -218,7 +298,7 @@ public final class CLALibLeftMultBy {
                                LMMParallel(noPreAggGroups, preAggGroups, that, 
ret, null, overlapping, k);
                }
 
-               ret.recomputeNonZeros();
+               ret.recomputeNonZeros(k);
                ret.examSparsity();
                return ret;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
index bc164a5e91..060c736871 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java
@@ -35,6 +35,21 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 
+/**
+ * Support compressed MM chain operation to fuse the following cases :
+ * 
+ * <p>
+ * XtXv == (t(X) %*% (X %*% v))
+ * </p>
+ * 
+ * <p>
+ * XtwXv == (t(X) %*% (w * (X %*% v)))
+ * </p>
+ *
+ * <p>
+ * XtXvy == (t(X) %*% ((X %*% v) - y))
+ * </p>
+ */
 public final class CLALibMMChain {
        static final Log LOG = LogFactory.getLog(CLALibMMChain.class.getName());
 
@@ -42,6 +57,33 @@ public final class CLALibMMChain {
                // private constructor
        }
 
+       /**
+        * Support compressed MM chain operation to fuse the following cases :
+        * 
+        * <p>
+        * XtXv == (t(X) %*% (X %*% v))
+        * </p>
+        * 
+        * <p>
+        * XtwXv == (t(X) %*% (w * (X %*% v)))
+        * </p>
+        *
+        * <p>
+        * XtXvy == (t(X) %*% ((X %*% v) - y))
+        * </p>
+        * 
+        * Note: the point of this optimization is that v and w are always vectors. In practice this means that all the compute
+        * is faster if the intermediates are exploited.
+        * 
+        * 
+        * @param x     Is the X part of the chain optimized kernel
+        * @param v     Is the mandatory v part of the chain
+        * @param w     Is the optional w part of the chain
+        * @param out   The output to put the result into. Can also be returned 
and in some cases will not be used.
+        * @param ctype either XtwXv, XtXv or XtXvy
+        * @param k     the parallelization degree
+        * @return The result either in the given output or a new allocation
+        */
        public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock 
v, MatrixBlock w, MatrixBlock out,
                ChainType ctype, int k) {
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
index 39468b0cab..2eef5f9f3f 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
@@ -243,7 +243,9 @@ public final class CLALibRightMultBy {
                catch(InterruptedException | ExecutionException e) {
                        throw new DMLRuntimeException(e);
                }
-               pool.shutdown();
+               finally{
+                       pool.shutdown();
+               }
                return containsNull;
        }
 

Reply via email to