Repository: systemml
Updated Branches:
  refs/heads/master ab8cccdff -> cca6356f8


[MINOR] Utility to obtain the exact output sparsity of sparse products

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/cca6356f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/cca6356f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/cca6356f

Branch: refs/heads/master
Commit: cca6356f8de49dcb6aeb1f23cefd53930309fedb
Parents: ab8cccd
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sun Oct 21 02:30:25 2018 +0200
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sun Oct 21 02:30:25 2018 +0200

----------------------------------------------------------------------
 .../sysml/hops/estim/EstimationUtils.java       | 56 ++++++++++++++++++++
 1 file changed, 56 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/cca6356f/src/main/java/org/apache/sysml/hops/estim/EstimationUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimationUtils.java b/src/main/java/org/apache/sysml/hops/estim/EstimationUtils.java
index d4d30bb..29de2f5 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimationUtils.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimationUtils.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.estim;
 
 import java.util.Arrays;
 
+import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.matrix.data.DenseBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.SparseBlock;
@@ -106,4 +107,59 @@ public abstract class EstimationUtils
                }
                return retNnz;
        }
+       
+       public static long getSparseProductOutputNnz(MatrixBlock m1, MatrixBlock m2) {
+               if( !m1.isInSparseFormat() || !m2.isInSparseFormat() )
+                       throw new DMLRuntimeException("Invalid call to sparse output nnz estimation.");
+               
+               final int m = m1.getNumRows();
+               final int n2 = m2.getNumColumns();
+               long retNnz = 0;
+               
+               SparseBlock a = m1.getSparseBlock();
+               SparseBlock b = m2.getSparseBlock();
+               
+               SparseRowVector tmpS = new SparseRowVector(1024);
+               double[] tmpD = null;
+                       
+               for( int i=0; i<m; i++ ) {
+                       if( a.isEmpty(i) ) continue;
+                       int alen = a.size(i);
+                       int apos = a.pos(i);
+                       int[] aix = a.indexes(i);
+                       double[] avals = a.values(i);
+                       
+                       //compute number of aggregated non-zeros for input row
+                       int nnz1 = (int) Math.min(UtilFunctions.computeNnz(b, aix, apos, alen), n2);
+                       boolean ldense = nnz1 > n2 / 128;
+                       
+                       //perform vector-matrix multiply w/ dense or sparse output
+                       if( ldense ) { //init dense tmp row
+                               tmpD = (tmpD == null) ? new double[n2] : tmpD;
+                               Arrays.fill(tmpD, 0);
+                       }
+                       else {
+                               tmpS.setSize(0);
+                       }
+                       for( int k=apos; k<apos+alen; k++ ) {
+                               if( b.isEmpty(aix[k]) ) continue;
+                               int blen = b.size(aix[k]);
+                               int bpos = b.pos(aix[k]);
+                               int[] bix = b.indexes(aix[k]);
+                               double aval = avals[k];
+                               double[] bvals = b.values(aix[k]);
+                               if( ldense ) { //dense aggregation
+                                       for( int j=bpos; j<bpos+blen; j++ )
+                                               tmpD[bix[j]] += aval * bvals[j];
+                               }
+                               else { //sparse aggregation
+                                       for( int j=bpos; j<bpos+blen; j++ )
+                                               tmpS.add(bix[j], aval * bvals[j]);
+                               }
+                       }
+                       retNnz += !ldense ? tmpS.size() :
+                               UtilFunctions.computeNnz(tmpD, 0, n2);
+               }
+               return retNnz;
+       }
 }
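
For context, below is a minimal usage sketch of the new utility; it is not part of
the commit. The MatrixBlock constructor, quickSetValue, and examSparsity calls are
assumptions about the SystemML runtime API, shown only to illustrate invoking
EstimationUtils.getSparseProductOutputNnz. Internally, the method bounds the
aggregated non-zeros per output row, accumulates into a dense temporary row when
that bound exceeds n2/128 (and into a sparse row vector otherwise), and sums the
exact per-row counts without materializing the full product block.

import org.apache.sysml.hops.estim.EstimationUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;

public class SparseProductNnzExample {
        public static void main(String[] args) {
                //build two sparse 1000 x 1000 matrices with a few scattered non-zeros
                //(constructor and setter names are assumed, for illustration only)
                MatrixBlock m1 = new MatrixBlock(1000, 1000, true);
                MatrixBlock m2 = new MatrixBlock(1000, 1000, true);
                m1.quickSetValue(0, 3, 7.0);
                m1.quickSetValue(1, 3, 2.0);
                m2.quickSetValue(3, 5, 1.5);
                m2.quickSetValue(3, 9, -4.0);
                m1.examSparsity(); //keep sparse representation
                m2.examSparsity();

                //exact nnz of m1 %*% m2 (here 4), without allocating the output block
                long nnz = EstimationUtils.getSparseProductOutputNnz(m1, m2);
                double sparsity = (double) nnz / (1000L * 1000L);
                System.out.println("output nnz=" + nnz + ", sparsity=" + sparsity);
        }
}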
